Skip to content

Commit 89e365a

Browse files
committed
feat(library): add quantize_symmetric op
1 parent d443109 commit 89e365a

File tree

4 files changed

+49
-0
lines changed

4 files changed

+49
-0
lines changed

quanto/library/ops.py

+1
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,5 @@ def impl(*args, **kwargs):
5353
return getattr(torch.ops.quanto_py, name)(*args, **kwargs)
5454

5555

56+
define("quantize_symmetric", "(Tensor self, Tensor scale, ScalarType dtype) -> Tensor")
5657
define("unpack", "(Tensor self, int bits) -> Tensor")

quanto/library/python/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
from .quantize import *
12
from .unpack import *

quanto/library/python/quantize.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import torch
2+
3+
4+
def dtype_info(dtype):
    """Return the numeric-limits object (``torch.finfo`` or ``torch.iinfo``) for *dtype*."""
    if dtype.is_floating_point:
        return torch.finfo(dtype)
    return torch.iinfo(dtype)
7+
8+
9+
@torch.library.impl("quanto_py::quantize_symmetric", "default")
def quantize_symmetric(t: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Quantize a float tensor to *dtype* using a symmetric scale.

    Args:
        t: the float tensor to quantize.
        scale: the quantization scale(s); must broadcast against ``t``
            (a scalar for per-tensor, or one value per axis for per-axis).
        dtype: the destination dtype (an integer dtype or a float8 dtype).

    Returns:
        The quantized tensor, of dtype ``dtype``.
    """
    # NOTE: the annotation was ``torch.Tensor.dtype`` (an attribute descriptor,
    # not a type); ``torch.dtype`` is the actual type of this argument.
    info = dtype_info(dtype)
    data = t / scale
    # Integer destinations need explicit rounding; float destinations rely on
    # the rounding performed by the final cast.
    if not dtype.is_floating_point:
        data = torch.round(data)
    # Saturate to the destination range before casting to avoid overflow.
    return torch.clamp(data, min=info.min, max=info.max).to(dtype)

test/library/test_quantize.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import pytest
2+
import torch
3+
from helpers import random_tensor
4+
5+
6+
@pytest.mark.parametrize("shape", [(12,), (32, 32)], ids=["vector", "matrix"])
@pytest.mark.parametrize("src_dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"])
@pytest.mark.parametrize("dst_dtype", [torch.int8, torch.float8_e4m3fn], ids=["int8", "float8"])
@pytest.mark.parametrize("per_axis", [True, False], ids=["per-axis", "per-tensor"])
def test_quantize_symmetric(shape, src_dtype, dst_dtype, per_axis, device):
    """Round-trip check: dequantize crafted data, then expect quantization to recover it."""
    if device.type == "mps" and dst_dtype != torch.int8:
        pytest.skip("float8 types are not supported on MPS device")
    # Craft reference quantized values directly in the destination dtype.
    if dst_dtype.is_floating_point:
        expected = random_tensor(shape, torch.float16).to(dst_dtype).to(device)
    else:
        expected = torch.randint(-127, 127, shape, dtype=dst_dtype).to(device)
    # Per-axis: one scale per leading-dim entry (broadcastable); per-tensor: a scalar.
    scale_shape = (shape[0],) + (1,) * (len(shape) - 1) if per_axis else ()
    scale = torch.rand(scale_shape, dtype=src_dtype).to(device)
    # Dequantize so the op receives a float input whose exact quantization is known.
    t = expected.to(src_dtype) * scale
    qdata = torch.ops.quanto.quantize_symmetric(t, scale, dst_dtype)
    assert qdata.dtype == dst_dtype
    assert qdata.shape == shape
    # float8 tensors direct comparison is not supported yet on CPU
    if dst_dtype.is_floating_point:
        assert torch.equal(qdata.to(torch.float16), expected.to(torch.float16))
    else:
        assert torch.equal(qdata, expected)

0 commit comments

Comments
 (0)