From 7cc4944a64920363074169fe528db090116a6191 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Thu, 23 Jan 2025 15:15:41 -0800
Subject: [PATCH] integrate new float8 quantization primitives into AQT

ghstack-source-id: c1deeeb84bdbb109a245b7e6c84150a724a012e7
ghstack-comment-id: 2608090492
Pull Request resolved: https://github.com/pytorch/ao/pull/1598
---
 torchao/dtypes/affine_quantized_tensor.py | 47 ++++++++++++++++++-----
 torchao/quantization/quant_primitives.py  |  8 ++--
 2 files changed, 43 insertions(+), 12 deletions(-)

diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index e7aca34c5f..d38c7181c1 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -4,27 +4,22 @@
 
 import torch
 
-from torchao.dtypes.utils import (
-    AQTTensorImpl,
-    Layout,
-    PlainLayout,
-)
+from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout
 from torchao.quantization.quant_primitives import (
     FP8_TYPES,
     MappingType,
     ZeroPointDomain,
     choose_qparams_affine,
+    choose_qparams_affine_float8,
     choose_qparams_affine_floatx,
     choose_qparams_and_quantize_affine_hqq,
     dequantize_affine,
     dequantize_affine_floatx,
     quantize_affine,
+    quantize_affine_float8,
     quantize_affine_floatx,
 )
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_5,
-    TorchAOBaseTensor,
-)
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor
 
 logger = logging.getLogger(__name__)
 aten = torch.ops.aten
@@ -422,6 +417,39 @@ def from_hp_to_fpx(
         tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
         return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)
 
+    @classmethod
+    def from_hp_to_float8(
+        cls,
+        input_float: torch.Tensor,
+        target_dtype: torch.dtype,
+        block_size: Tuple[int, ...],
+        _layout: Layout = PlainLayout(),
+    ):
+        assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8"
+
+        # to avoid circular dependency
+        from torchao.dtypes.floatx import Float8AQTTensorImpl
+
+        original_shape = input_float.shape
+        scale = choose_qparams_affine_float8(
+            input_float,
+            target_dtype,
+            target_dtype,
+        )
+        fp8_data = quantize_affine_float8(
+            input_float,
+            scale,
+            target_dtype,
+        )
+        fp8_data = _layout.post_process(fp8_data)
+        tensor_impl = Float8AQTTensorImpl(fp8_data, scale, None, _layout)
+        return cls(
+            tensor_impl,
+            block_size,
+            original_shape,
+            dtype=input_float.dtype,
+        )
+
     @property
     def _layout(self) -> Layout:
         return self.tensor_impl._layout
@@ -477,6 +505,7 @@ def _apply_fn_to_data(self, fn):
 to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
 to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
 to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
+to_affine_quantized_float8 = AffineQuantizedTensor.from_hp_to_float8
 # experimental will be merged in to floatx
 to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx
 
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index bacc8b2f6c..fd6acbe994 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -5,13 +5,15 @@
 # LICENSE file in the root directory of this source tree.
 
 import math
-from enum import auto, Enum
+from enum import Enum, auto
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 
 from torchao.float8.float8_utils import (
     ScalingGranularity,
+)
+from torchao.float8.float8_utils import (
     tensor_to_scale as tensor_to_float8_scale,
 )
 from torchao.prototype.custom_fp_utils import (
@@ -20,11 +22,11 @@
     _n_ones,
 )
 from torchao.utils import (
-    _is_float8_type,
-    _register_custom_op,
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    _is_float8_type,
+    _register_custom_op,
 )
 
 __all__ = [
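
Usage sketch (not part of the patch): a minimal example of the float8 path added above, assuming a torchao build with this patch applied. The alias `to_affine_quantized_float8` and the `from_hp_to_float8` parameters come from the diff; the tensor shape, the bfloat16/e4m3 dtype choices, and the import spot are illustrative assumptions, not part of the patch.

    import torch

    # The alias is defined at module level in this patch; whether it is
    # re-exported from a higher-level package is not shown in the diff.
    from torchao.dtypes.affine_quantized_tensor import to_affine_quantized_float8

    # High-precision weight to quantize (illustrative shape and dtype).
    weight = torch.randn(256, 512, dtype=torch.bfloat16)

    # A block_size equal to the full shape corresponds to per-tensor scaling.
    weight_fp8 = to_affine_quantized_float8(
        weight,
        target_dtype=torch.float8_e4m3fn,  # must be one of FP8_TYPES
        block_size=weight.shape,
    )

    # The wrapper keeps the original shape and the high-precision dtype metadata
    # (the patch passes dtype=input_float.dtype to the constructor).
    print(weight_fp8.shape, weight_fp8.dtype)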