Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
danielvegamyhre committed Jan 23, 2025
2 parents 76c7bde + 5061252 commit e116ed8
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
6 changes: 3 additions & 3 deletions torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@

from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout
from torchao.quantization.quant_primitives import (
FP8_TYPES,
MappingType,
ZeroPointDomain,
choose_qparams_affine,
choose_qparams_affine_float8,
choose_qparams_affine_floatx,
choose_qparams_and_quantize_affine_hqq,
dequantize_affine,
dequantize_affine_floatx,
FP8_TYPES,
MappingType,
quantize_affine,
quantize_affine_float8,
quantize_affine_floatx,
ZeroPointDomain,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor

Expand Down
7 changes: 5 additions & 2 deletions torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,7 +1346,7 @@ def quantize_affine_float8(
float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
"""
# Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
# upcasted to `float32` to multiply with the scale
# upcast to `float32` to multiply with the scale, since the scale is an fp32 tensor in float8 quantization.
# In order to match numerics between eager and compile, we upcast manually here.
tensor_scaled = tensor.to(torch.float32) * scale
max_value = torch.finfo(float8_dtype).max
Expand All @@ -1361,13 +1361,16 @@ def dequantize_affine_float8(
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Dequantizes the float8 tensor to float32 tensor.
Dequantizes the float8 tensor to a high-precision tensor.
Args:
tensor (torch.Tensor): Input float8 tensor to be dequantized.
scale (torch.Tensor): Scaling factor for the dequantization.
output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32).
"""
# Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
# upcast to `float32` to divide by the scale, since the scale is an fp32 tensor in float8 quantization.
# In order to match numerics between eager and compile, we upcast manually here.
fp8_tensor = tensor.to(torch.float32)
hp_tensor = fp8_tensor / scale
return hp_tensor.to(output_dtype)

0 comments on commit e116ed8

Please sign in to comment.