Skip to content

Commit

Permalink
integrate new float8 quantization primitives into AQT
Browse files Browse the repository at this point in the history
ghstack-source-id: c1deeeb84bdbb109a245b7e6c84150a724a012e7
ghstack-comment-id: 2608090492
Pull Request resolved: #1598
  • Loading branch information
danielvegamyhre committed Jan 23, 2025
1 parent ac3dc8d commit 7cc4944
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 12 deletions.
47 changes: 38 additions & 9 deletions torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,22 @@

import torch

from torchao.dtypes.utils import (
AQTTensorImpl,
Layout,
PlainLayout,
)
from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout
from torchao.quantization.quant_primitives import (
FP8_TYPES,
MappingType,
ZeroPointDomain,
choose_qparams_affine,
choose_qparams_affine_float8,
choose_qparams_affine_floatx,
choose_qparams_and_quantize_affine_hqq,
dequantize_affine,
dequantize_affine_floatx,
quantize_affine,
quantize_affine_float8,
quantize_affine_floatx,
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
TorchAOBaseTensor,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor

logger = logging.getLogger(__name__)
aten = torch.ops.aten
Expand Down Expand Up @@ -422,6 +417,39 @@ def from_hp_to_fpx(
tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)

@classmethod
def from_hp_to_float8(
    cls,
    input_float: torch.Tensor,
    target_dtype: torch.dtype,
    block_size: Tuple[int, ...],
    _layout: Layout = PlainLayout(),
):
    """Build a float8 quantized tensor from a high-precision tensor.

    Args:
        input_float: Tensor to quantize. Its shape and dtype are recorded
            on the returned tensor.
        target_dtype: Float8 dtype to quantize to; must be a member of
            ``FP8_TYPES``.
        block_size: Forwarded unchanged to the ``cls`` constructor. It is
            not passed to either float8 primitive called in this method.
        _layout: Layout whose ``post_process`` is applied to the quantized
            data and which is stored on the tensor impl.

    Returns:
        An instance of ``cls`` wrapping a ``Float8AQTTensorImpl``.

    Raises:
        AssertionError: If ``target_dtype`` is not in ``FP8_TYPES``.
    """
    assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8"

    # Imported lazily: a module-level import would be circular.
    from torchao.dtypes.floatx import Float8AQTTensorImpl

    # Remember the shape before any quantization call touches the input.
    hp_shape = input_float.shape

    # NOTE(review): target_dtype is passed as both the 2nd and 3rd positional
    # argument - confirm this matches choose_qparams_affine_float8's signature.
    fp8_scale = choose_qparams_affine_float8(input_float, target_dtype, target_dtype)
    quantized = quantize_affine_float8(input_float, fp8_scale, target_dtype)

    # The layout may transform the raw float8 data before it is wrapped;
    # no zero point is supplied (third argument is None).
    impl = Float8AQTTensorImpl(
        _layout.post_process(quantized), fp8_scale, None, _layout
    )
    return cls(impl, block_size, hp_shape, dtype=input_float.dtype)

@property
def _layout(self) -> Layout:
    """Return the layout stored on the wrapped tensor implementation."""
    impl = self.tensor_impl
    return impl._layout
Expand Down Expand Up @@ -477,6 +505,7 @@ def _apply_fn_to_data(self, fn):
to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
# Module-level alias for the float8 constructor classmethod.
to_affine_quantized_float8 = AffineQuantizedTensor.from_hp_to_float8
# Experimental; will be merged into floatx.
to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx

Expand Down
8 changes: 5 additions & 3 deletions torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
# LICENSE file in the root directory of this source tree.

import math
from enum import auto, Enum
from enum import Enum, auto
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch

from torchao.float8.float8_utils import (
ScalingGranularity,
)
from torchao.float8.float8_utils import (
tensor_to_scale as tensor_to_float8_scale,
)
from torchao.prototype.custom_fp_utils import (
Expand All @@ -20,11 +22,11 @@
_n_ones,
)
from torchao.utils import (
_is_float8_type,
_register_custom_op,
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
_is_float8_type,
_register_custom_op,
)

__all__ = [
Expand Down

0 comments on commit 7cc4944

Please sign in to comment.