From ac3dc8d12072d187487d80e766852d75c0fe0a5e Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Thu, 23 Jan 2025 15:15:41 -0800
Subject: [PATCH] add separate quantization primitives for float8

ghstack-source-id: ecc4358ee5c4c337a0c567ca4fdde3f0570ec060
ghstack-comment-id: 2608048970
Pull Request resolved: https://github.com/pytorch/ao/pull/1597
---
 torchao/quantization/quant_primitives.py | 78 +++++++++++++++++++++++-
 1 file changed, 75 insertions(+), 3 deletions(-)

diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index e587d4bc2b..bacc8b2f6c 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -5,22 +5,26 @@
 # LICENSE file in the root directory of this source tree.
 
 import math
-from enum import Enum, auto
+from enum import auto, Enum
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 
+from torchao.float8.float8_utils import (
+    ScalingGranularity,
+    tensor_to_scale as tensor_to_float8_scale,
+)
 from torchao.prototype.custom_fp_utils import (
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
     _n_ones,
 )
 from torchao.utils import (
+    _is_float8_type,
+    _register_custom_op,
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
-    _is_float8_type,
-    _register_custom_op,
 )
 
 __all__ = [
@@ -39,6 +43,9 @@
     "MappingType",
     "ZeroPointDomain",
     "TorchAODType",
+    "choose_qparams_affine_float8",
+    "quantize_affine_float8",
+    "dequantize_affine_float8",
 ]
 
 
@@ -1300,3 +1307,68 @@ def dequantize_affine_floatx(
     tensor = tensor * scale.float().view(-1, 1)
     tensor = tensor.to(dtype=output_dtype)
     return tensor
+
+
+def choose_qparams_affine_float8(
+    tensor: torch.Tensor, float8_dtype: torch.dtype
+) -> torch.Tensor:
+    """
+    Calculates the float8 scaling factor for the given high precision tensor, using axiswise granularity (axis=1).
+
+    Args:
+        tensor (torch.Tensor): Input tensor to be quantized.
+        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
+    """
+    # NOTE: quantization primitives are hardcoded to use axiswise granularity w/ axis=1 right now:
+    # https://github.com/pytorch/ao/blob/5d1444bdef6df15eb89c4c5716ede1c5f8677798/torchao/dtypes/affine_quantized_tensor.py#L416
+    scale = tensor_to_float8_scale(
+        tensor,
+        float8_dtype,
+        scaling_granularity=ScalingGranularity.AXISWISE,
+        axiswise_dim=1,
+    )
+    return scale
+
+
+def quantize_affine_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    float8_dtype: torch.dtype,
+) -> torch.Tensor:
+    """
+    Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
+
+    Args:
+        tensor (torch.Tensor): Input tensor to be quantized.
+        scale (torch.Tensor): Scaling factor for the quantization.
+        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
+    """
+    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
+    # upcast to `float32` to multiply with the scale, since the scale is an fp32 tensor in float8 quantization.
+    # In order to match numerics between eager and compile, we upcast manually here.
+    tensor_scaled = tensor.to(torch.float32) * scale
+    max_value = torch.finfo(float8_dtype).max
+    tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
+    fp8_tensor = tensor_clamped.to(float8_dtype)
+    return fp8_tensor
+
+
+def dequantize_affine_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """
+    Dequantizes the float8 tensor to a high precision tensor.
+
+    Args:
+        tensor (torch.Tensor): Input float8 tensor to be dequantized.
+        scale (torch.Tensor): Scaling factor for the dequantization.
+        output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32).
+    """
+    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
+    # upcast to `float32` to divide by the scale, since the scale is an fp32 tensor for float8 quantization.
+    # In order to match numerics between eager and compile, we upcast manually here.
+    fp8_tensor = tensor.to(torch.float32)
+    hp_tensor = fp8_tensor / scale
+    return hp_tensor.to(output_dtype)
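
For review purposes, a round-trip through the three new primitives looks roughly like
the sketch below. It is illustrative only, not part of the patch: it assumes a 2D
bfloat16 input (since the scale computation above is currently hardcoded to axiswise
granularity along axis 1) and picks torch.float8_e4m3fn as the target dtype.

    import torch

    from torchao.quantization.quant_primitives import (
        choose_qparams_affine_float8,
        dequantize_affine_float8,
        quantize_affine_float8,
    )

    # 2D input; the amax reduction runs along axis 1, i.e. one scale per row
    x = torch.randn(16, 32, dtype=torch.bfloat16)

    # fp32 scale following the torchao.float8 convention (float8_max / amax)
    scale = choose_qparams_affine_float8(x, torch.float8_e4m3fn)

    # quantize multiplies by the scale, clamps to the float8 range, and casts
    x_fp8 = quantize_affine_float8(x, scale, torch.float8_e4m3fn)

    # dequantize divides by the scale and casts to the requested output dtype
    x_hp = dequantize_affine_float8(x_fp8, scale, output_dtype=torch.bfloat16)

    # round-trip error should be small relative to each row's amax
    print((x.float() - x_hp.float()).abs().max())

Note the scale convention is the inverse of the integer affine primitives in this
file: quantize_affine divides by its scale, while quantize_affine_float8 multiplies,
because tensor_to_scale returns float8_max / amax rather than amax / float8_max.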