
Commit 1606f4a

feat(quanto): introduce qtype

This implies a lot of modifications but is functionally equivalent.

1 parent 1b66ea9

25 files changed: +285 −389 lines

bench/generation/benchmark.py (+3 −3)

@@ -7,7 +7,7 @@
 from tqdm.auto import tqdm
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig

-from quanto import Calibration, freeze, quantize
+from quanto import Calibration, freeze, qint8, quantize


 CALIBRATION_PROMPT = "It was a bright cold day in April, and the clocks were striking thirteen."

@@ -165,8 +165,8 @@ def main():
     if args.quantization in ("w8a8", "w8a16"):
         print("quantizing")
         start = time.time()
-        weights = torch.int8
-        activations = None if "a16" in args.quantization else torch.int8
+        weights = qint8
+        activations = None if "a16" in args.quantization else qint8
         quantize(model, weights=weights, activations=activations)
         if activations is not None:
             print("Calibrating")

examples/nlp/text-classification/sst2/quantize_sst2_model.py (+2 −2)

@@ -8,7 +8,7 @@
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
 from transformers.pipelines.pt_utils import KeyDataset

-from quanto import Calibration, freeze, quantize
+from quanto import Calibration, freeze, qint8, quantize


 def evaluate_model(model, tokenizer, dataset, device, batch_size):

@@ -22,7 +22,7 @@ def evaluate_model(model, tokenizer, dataset, device, batch_size):


 def keyword_to_itype(k):
-    return {"none": None, "int8": torch.int8}[k]
+    return {"none": None, "int8": qint8}[k]


 def main():

examples/nlp/text-generation/quantize_causal_lm_model.py (+2 −2)

@@ -5,7 +5,7 @@
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer

-from quanto import Calibration, freeze, quantize
+from quanto import Calibration, freeze, qfloat8_e4m3fn, qfloat8_e5m2, qint8, quantize


 @torch.no_grad()

@@ -51,7 +51,7 @@ def evaluate_model(model, tokenizer, dataset, device, batch_size, samples=None,


 def keyword_to_itype(k):
-    return {"none": None, "int8": torch.int8, "fp8_e5m2": torch.float8_e5m2, "fp8_e4m3": torch.float8_e4m3fn}[k]
+    return {"none": None, "int8": qint8, "fp8_e5m2": qfloat8_e5m2, "fp8_e4m3": qfloat8_e4m3fn}[k]


 def main():

examples/vision/image-classification/mnist/quantize_mnist_model.py (+2 −2)

@@ -7,7 +7,7 @@
 from torchvision import datasets, transforms
 from transformers import AutoModel

-from quanto import Calibration, QTensor, freeze, int4, quantize
+from quanto import Calibration, QTensor, freeze, qint4, qint8, quantize


 def test(model, device, test_loader):

@@ -60,7 +60,7 @@ def train(log_interval, model, device, train_loader, optimizer, epoch):


 def keyword_to_itype(k):
-    return {"none": None, "int4": int4, "int8": torch.int8}[k]
+    return {"none": None, "int4": qint4, "int8": qint8}[k]


 def main():

quanto/nn/qlinear.py (+8 −8)

@@ -2,7 +2,7 @@

 import torch

-from ..tensor import QBitsTensor, QTensor, absmax_scale, qbitsdtype
+from ..tensor import QBitsTensor, QTensor, absmax_scale, qint2, qint4, qint8, qtype
 from .qmodule import QModuleMixin, register_qmodule


@@ -11,12 +11,12 @@

 @register_qmodule(torch.nn.Linear)
 class QLinear(QModuleMixin, torch.nn.Linear):
-    def __init__(self, *args, weights: torch.dtype = torch.int8, **kwargs):
+    def __init__(self, *args, weights: qtype = qint8, **kwargs):
         super().__init__(*args, **kwargs)
         self.weights = weights

     @classmethod
-    def from_module(cls, module, weights=torch.int8, activations: Optional[torch.dtype] = None):
+    def from_module(cls, module, weights=qint8, activations: Optional[qtype] = None):
         qmodule = cls(
             module.in_features,
             module.out_features,

@@ -36,12 +36,12 @@ def qweight(self):
         if isinstance(self.weight, QTensor):
             return self.weight
         # Quantize the weights per-axis
-        if isinstance(self.weights, torch.dtype):
+        if self.weights == qint8:
             wscale = absmax_scale(self.weight, axis=0)
-            return QTensor.quantize(self.weight, itype=self.weights, scale=wscale)
-        elif isinstance(self.weights, qbitsdtype):
-            return QBitsTensor.quantize(self.weight, itype=self.weights, axis=0)
-        raise ValueError("Invalid quantized weights type")
+            return QTensor.quantize(self.weight, qtype=self.weights, scale=wscale)
+        elif self.weights in (qint2, qint4):
+            return QBitsTensor.quantize(self.weight, qtype=self.weights, axis=0)
+        raise ValueError(f"Invalid quantized weights type {self.weights}")

     def qforward(self, input: torch.Tensor) -> torch.Tensor:
         if self.activations is not None and not isinstance(input, QTensor):
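
Note: qweight now dispatches on concrete qtype values rather than on the Python type of the weights attribute: qint8 weights are quantized per-axis into a QTensor with an absmax scale, while qint2 and qint4 weights go through QBitsTensor. A usage sketch under the assumption that QLinear is exported from quanto.nn (module sizes are illustrative):

import torch

from quanto import qint4, qint8
from quanto.nn import QLinear

linear = torch.nn.Linear(64, 32)

# qint8 -> QTensor quantized along axis 0 with an absmax scale.
q8 = QLinear.from_module(linear, weights=qint8)

# qint2/qint4 -> QBitsTensor; any other qtype raises ValueError.
q4 = QLinear.from_module(linear, weights=qint4)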

quanto/nn/qmodule.py (+2 −2)

@@ -94,9 +94,9 @@ def qforward(self, input: torch.Tensor) -> torch.Tensor:

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         def maybe_requantize(t, scale):
-            if t.itype == self.activations and t.axis is None:
+            if t.qtype == self.activations and t.axis is None:
                 return t
-            return QTensor.quantize(t.dequantize(), itype=self.activations, scale=scale)
+            return QTensor.quantize(t.dequantize(), qtype=self.activations, scale=scale)

         if self.activations is not None and isinstance(input, QTensor):
             input = maybe_requantize(input, self.input_scale)
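
Note: maybe_requantize keeps an incoming QTensor as-is only when it already matches the module's activation qtype and is quantized per-tensor (axis is None); otherwise it dequantizes and requantizes with the calibrated scale. A small illustration of that distinction, assuming absmax_scale defaults to a per-tensor scale when axis is omitted:

import torch

from quanto import QTensor, qint8
from quanto.tensor import absmax_scale

t = torch.randn(16, 8)

# A per-axis QTensor misses the fast path above (its axis is not None)...
per_axis = QTensor.quantize(t, qtype=qint8, scale=absmax_scale(t, axis=0))

# ...so it would be dequantized and requantized per-tensor, mirroring the slow path.
per_tensor = QTensor.quantize(per_axis.dequantize(), qtype=qint8, scale=absmax_scale(t))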

quanto/tensor/__init__.py (+1 −0)

@@ -1 +1,2 @@
 from .core import *
+from .qtype import *
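
Note: this export is the only glimpse of the new quanto/tensor/qtype.py in this excerpt. A plausible sketch of its contents, assuming qtype is a small dataclass bundling a name, a bit width, and the torch dtype used for storage; the field names are inferred from how the values are used elsewhere in this commit, not confirmed by the diff:

from dataclasses import dataclass

import torch


@dataclass(frozen=True)
class qtype:
    """A quantized type descriptor, standing in for the torch.dtype used before."""

    name: str
    is_floating_point: bool
    bits: int
    dtype: torch.dtype  # storage dtype holding the quantized data

    def __str__(self):
        return f"quanto.{self.name}"


# The values referenced by the diffs above (hypothetical definitions).
qint2 = qtype("qint2", is_floating_point=False, bits=2, dtype=torch.int8)
qint4 = qtype("qint4", is_floating_point=False, bits=4, dtype=torch.int8)
qint8 = qtype("qint8", is_floating_point=False, bits=8, dtype=torch.int8)
qfloat8_e4m3fn = qtype("qfloat8_e4m3fn", is_floating_point=True, bits=8, dtype=torch.float8_e4m3fn)
qfloat8_e5m2 = qtype("qfloat8_e5m2", is_floating_point=True, bits=8, dtype=torch.float8_e5m2)

Because dataclass equality compares fields, plain instances like these would satisfy the self.weights == qint8 and self.weights in (qint2, qint4) checks in qlinear.py.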
