|
| 1 | +import pytest |
| 2 | +import torch |
| 3 | +from helpers import random_tensor |
| 4 | + |
| 5 | + |
@pytest.mark.parametrize("shape", [(12,), (32, 32)], ids=["vector", "matrix"])
@pytest.mark.parametrize("src_dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"])
@pytest.mark.parametrize("dst_dtype", [torch.int8, torch.float8_e4m3fn], ids=["int8", "float8"])
@pytest.mark.parametrize("per_axis", [True, False], ids=["per-axis", "per-tensor"])
def test_quantize_symmetric(shape, src_dtype, dst_dtype, per_axis, device):
    """Round-trip check for the quanto symmetric quantization op.

    Builds quantized data and a scale by hand, dequantizes them into a float
    tensor, then verifies that `quantize_symmetric` maps that tensor back to
    exactly the original quantized values.
    """
    # float8 destinations have no kernel support on Apple MPS — skip there.
    if device.type == "mps" and dst_dtype != torch.int8:
        pytest.skip("float8 types are not supported on MPS device")

    # Craft the reference quantized data directly in the destination dtype.
    if dst_dtype.is_floating_point:
        data = random_tensor(shape, torch.float16).to(dst_dtype).to(device)
    else:
        data = torch.randint(-127, 127, shape, dtype=dst_dtype).to(device)

    # Per-axis scales broadcast along the first dimension; per-tensor is a scalar.
    scale_shape = (shape[0],) + (1,) * (len(shape) - 1) if per_axis else ()
    scale = torch.rand(scale_shape, dtype=src_dtype).to(device)

    # Dequantize to obtain the float tensor the op will re-quantize.
    t = data.to(src_dtype) * scale

    qdata = torch.ops.quanto.quantize_symmetric(t, scale, dst_dtype)
    assert qdata.dtype == dst_dtype
    assert qdata.shape == shape

    # float8 tensors direct comparison is not supported yet on CPU, so
    # compare through an intermediate fp16 view for floating destinations.
    expected_equal = (
        torch.equal(qdata.to(torch.float16), data.to(torch.float16))
        if dst_dtype.is_floating_point
        else torch.equal(qdata, data)
    )
    assert expected_equal
0 commit comments