diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 102e76cb1a..77616c1c6a 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -9,16 +9,21 @@ import unittest import torch +from parameterized import parameterized from torchao.dtypes.utils import is_device +from torchao.float8.float8_utils import EPS as float8_eps from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, choose_qparams_affine, + choose_qparams_affine_float8, dequantize_affine, + dequantize_affine_float8, fake_quantize_affine, fake_quantize_affine_cachemask, quantize_affine, + quantize_affine_float8, ) # TODO: remove test for utils? @@ -838,6 +843,71 @@ def test_fake_quantize_affine_cachemask(self): torch.testing.assert_close(dequantized, fake_quantized) torch.testing.assert_close(expected_mask, mask) + @parameterized.expand( + [ + ( + torch.float32, + torch.float8_e4m3fn, + ), + ( + torch.float32, + torch.float8_e5m2, + ), + ( + torch.bfloat16, + torch.float8_e4m3fn, + ), + ( + torch.bfloat16, + torch.float8_e5m2, + ), + ] + ) + def test_float8_quant_primitives(self, hp_dtype, float8_dtype): + input = torch.randn(10, 10) + + # float8 quantization primitives + scale = choose_qparams_affine_float8(input, float8_dtype=float8_dtype) + quantized = quantize_affine_float8(input, scale, float8_dtype=float8_dtype) + dequantized = dequantize_affine_float8(quantized, scale, output_dtype=hp_dtype) + + # reference implementation using generic primitives + expected_scale, _ = choose_qparams_affine( + input, + MappingType.SYMMETRIC, + input.shape, + float8_dtype, + eps=float8_eps, # use same EPS as float8 training + scale_dtype=torch.float32, + quant_min=torch.finfo(float8_dtype).min, + quant_max=torch.finfo(float8_dtype).max, + ) + expected_quantized = quantize_affine( + input, + input.shape, + scale, + output_dtype=float8_dtype, + quant_min=torch.finfo(float8_dtype).min, + quant_max=torch.finfo(float8_dtype).max, + zero_point=None, + zero_point_domain=None, + ) + expected_dequantized = dequantize_affine( + expected_quantized, + input.shape, + scale, + input_dtype=float8_dtype, + output_dtype=hp_dtype, + quant_min=torch.finfo(float8_dtype).min, + quant_max=torch.finfo(float8_dtype).max, + zero_point=None, + zero_point_domain=None, + ) + + self.assertTrue(torch.equal(expected_scale, scale)) + torch.testing.assert_close(expected_quantized, quantized) + torch.testing.assert_close(expected_dequantized, dequantized) + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index e587d4bc2b..8b0ce28434 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -39,6 +39,9 @@ "MappingType", "ZeroPointDomain", "TorchAODType", + "choose_qparams_affine_float8", + "quantize_affine_float8", + "dequantize_affine_float8", ] @@ -1300,3 +1303,67 @@ def dequantize_affine_floatx( tensor = tensor * scale.float().view(-1, 1) tensor = tensor.to(dtype=output_dtype) return tensor + + +def choose_qparams_affine_float8( + tensor: torch.Tensor, + float8_dtype: torch.dtype = torch.float8_e4m3fn, +) -> torch.Tensor: + """ + Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity. + + Args: + tensor (torch.Tensor): Input tensor to be quantized. + float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2). 
+    """
+    # only tensorwise scaling is supported for now:
+    quant_min, quant_max = torch.finfo(float8_dtype).min, torch.finfo(float8_dtype).max
+    min_val_neg = torch.min(tensor)
+    max_val_pos = torch.max(tensor)
+    max_val_pos = torch.max(-min_val_neg, max_val_pos)
+    scale = max_val_pos / (float(quant_max - quant_min) / 2)
+    return scale.to(dtype=torch.float32)
+
+
+def quantize_affine_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    float8_dtype: torch.dtype = torch.float8_e4m3fn,
+) -> torch.Tensor:
+    """
+    Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
+
+    Args:
+        tensor (torch.Tensor): Input tensor to be quantized.
+        scale (torch.Tensor): Scaling factor for the quantization.
+        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
+    """
+    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
+    # upcasted to `float32` to be divided by the scale, since scale is an fp32 tensor in float8 quantization.
+    # In order to match numerics between eager and compile, we upcast manually here.
+    tensor_scaled = tensor.to(torch.float32) / scale
+    max_value = torch.finfo(float8_dtype).max
+    tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
+    fp8_tensor = tensor_clamped.to(float8_dtype)
+    return fp8_tensor
+
+
+def dequantize_affine_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """
+    Dequantizes the float8 tensor to a high precision tensor.
+
+    Args:
+        tensor (torch.Tensor): Input float8 tensor to be dequantized.
+        scale (torch.Tensor): Scaling factor for the dequantization.
+        output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32).
+    """
+    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
+    # upcasted to `float32` to be multiplied by the scale, since scale is an fp32 tensor in float8 quantization.
+    # In order to match numerics between eager and compile, we upcast manually here.
+    fp8_tensor = tensor.to(torch.float32)
+    hp_tensor = fp8_tensor * scale
+    return hp_tensor.to(output_dtype)
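For context, a minimal round-trip sketch of the three primitives added above (not part of the diff; it assumes the patch is applied so the functions are importable from torchao.quantization.quant_primitives, matching the test's imports):

    import torch

    from torchao.quantization.quant_primitives import (
        choose_qparams_affine_float8,
        dequantize_affine_float8,
        quantize_affine_float8,
    )

    x = torch.randn(10, 10, dtype=torch.bfloat16)

    # Tensorwise scale derived from the absolute maximum of the input.
    scale = choose_qparams_affine_float8(x, float8_dtype=torch.float8_e4m3fn)

    # Quantize to float8, then reconstruct in the original high precision dtype;
    # x_hp matches x only up to float8 rounding error.
    x_fp8 = quantize_affine_float8(x, scale, float8_dtype=torch.float8_e4m3fn)
    x_hp = dequantize_affine_float8(x_fp8, scale, output_dtype=torch.bfloat16)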