add separate quantization primitives for float8
ghstack-source-id: ecc4358ee5c4c337a0c567ca4fdde3f0570ec060
ghstack-comment-id: 2608048970
Pull Request resolved: #1597
danielvegamyhre committed Jan 23, 2025
1 parent 32d9b0b commit ac3dc8d
Showing 1 changed file with 75 additions and 3 deletions.
78 changes: 75 additions & 3 deletions torchao/quantization/quant_primitives.py
@@ -5,22 +5,26 @@
 # LICENSE file in the root directory of this source tree.
 
 import math
-from enum import Enum, auto
+from enum import auto, Enum
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 
+from torchao.float8.float8_utils import (
+    ScalingGranularity,
+    tensor_to_scale as tensor_to_float8_scale,
+)
 from torchao.prototype.custom_fp_utils import (
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
     _n_ones,
 )
 from torchao.utils import (
+    _is_float8_type,
+    _register_custom_op,
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
-    _is_float8_type,
-    _register_custom_op,
 )
 
 __all__ = [
@@ -39,6 +43,9 @@
     "MappingType",
     "ZeroPointDomain",
     "TorchAODType",
+    "choose_qparams_affine_float8",
+    "quantize_affine_float8",
+    "dequantize_affine_float8",
 ]


@@ -1300,3 +1307,68 @@ def dequantize_affine_floatx(
     tensor = tensor * scale.float().view(-1, 1)
     tensor = tensor.to(dtype=output_dtype)
     return tensor
+
+
+def choose_qparams_affine_float8(
+    tensor: torch.Tensor, float8_dtype: torch.dtype
+) -> torch.Tensor:
+    """
+    Calculates the float8 scaling factor for the given high-precision tensor, using axiswise granularity along dim=1.
+
+    Args:
+        tensor (torch.Tensor): Input tensor to be quantized.
+        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
+    """
+    # NOTE: quantization primitives are hardcoded to use axiswise granularity w/ axis=1 right now:
+    # https://github.com/pytorch/ao/blob/5d1444bdef6df15eb89c4c5716ede1c5f8677798/torchao/dtypes/affine_quantized_tensor.py#L416
+    scale = tensor_to_float8_scale(
+        tensor,
+        float8_dtype,
+        scaling_granularity=ScalingGranularity.AXISWISE,
+        axiswise_dim=1,
+    )
+    return scale
+
+
+def quantize_affine_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    float8_dtype: torch.dtype,
+) -> torch.Tensor:
+    """
+    Quantizes a high-precision floating-point tensor to a float8 tensor, using the given scaling factor.
+
+    Args:
+        tensor (torch.Tensor): Input tensor to be quantized.
+        scale (torch.Tensor): Scaling factor for the quantization.
+        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
+    """
+    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
+    # upcast to `float32` to multiply with the scale, since the scale is an fp32 tensor in float8 quantization.
+    # To match numerics between eager and compile, we upcast manually here.
+    tensor_scaled = tensor.to(torch.float32) * scale
+    max_value = torch.finfo(float8_dtype).max
+    tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
+    fp8_tensor = tensor_clamped.to(float8_dtype)
+    return fp8_tensor
+
+
+def dequantize_affine_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """
+    Dequantizes a float8 tensor to a high-precision tensor.
+
+    Args:
+        tensor (torch.Tensor): Input float8 tensor to be dequantized.
+        scale (torch.Tensor): Scaling factor for the dequantization.
+        output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32).
+    """
+    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
+    # upcast to `float32` to divide by the scale, since the scale is an fp32 tensor in float8 quantization.
+    # To match numerics between eager and compile, we upcast manually here.
+    fp8_tensor = tensor.to(torch.float32)
+    hp_tensor = fp8_tensor / scale
+    return hp_tensor.to(output_dtype)
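
A minimal round-trip sketch of the three new primitives follows, assuming a torchao build that includes this commit (so the functions are importable from torchao.quantization.quant_primitives, per the __all__ update above); the shapes and dtype choice are illustrative only.

import torch

from torchao.quantization.quant_primitives import (
    choose_qparams_affine_float8,
    dequantize_affine_float8,
    quantize_affine_float8,
)

# High-precision input. choose_qparams_affine_float8 is currently
# hardcoded to axiswise granularity along dim=1, so each row of x
# gets its own scale.
x = torch.randn(128, 256, dtype=torch.float32)

# Per-row scales for the e4m3 format.
scale = choose_qparams_affine_float8(x, torch.float8_e4m3fn)

# Quantize: upcast to fp32, multiply by the scale, clamp to the
# float8 representable range, cast to float8.
x_fp8 = quantize_affine_float8(x, scale, torch.float8_e4m3fn)

# Dequantize: upcast to fp32 and divide by the same scale.
x_hp = dequantize_affine_float8(x_fp8, scale, output_dtype=torch.float32)

# The round trip is lossy; the error is bounded by float8 resolution.
print((x - x_hp).abs().max())

Note the scale convention these primitives use: quantize_affine_float8 multiplies by the scale and dequantize_affine_float8 divides by it, i.e. the scale returned by torchao.float8's tensor_to_scale is a multiplier into the float8 domain (finfo(float8_dtype).max / amax). With torch.float8_e4m3fn (max 448) and a row whose absolute max is 2.0, for example, the scale is 224. This is the inverse of the convention in the integer affine primitives earlier in this file, which divide by the scale when quantizing.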
