
Commit

Update

[ghstack-poisoned]
danielvegamyhre committed Jan 23, 2025
1 parent 1cbc037 commit 5061252
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions torchao/quantization/quant_primitives.py
@@ -5,15 +5,13 @@
 # LICENSE file in the root directory of this source tree.
 
 import math
-from enum import Enum, auto
+from enum import auto, Enum
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 
 from torchao.float8.float8_utils import (
     ScalingGranularity,
-)
-from torchao.float8.float8_utils import (
     tensor_to_scale as tensor_to_float8_scale,
 )
 from torchao.prototype.custom_fp_utils import (
@@ -22,11 +20,11 @@
     _n_ones,
 )
 from torchao.utils import (
-    _is_float8_type,
-    _register_custom_op,
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    _is_float8_type,
+    _register_custom_op,
 )
 
 __all__ = [
@@ -1346,7 +1344,7 @@ def quantize_affine_float8(
         float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
     """
     # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
-    # upcasted to `float32` to multiply with the scale
+    # upcasted to `float32` to multiply with the scale, since scale is a fp32 tensor in float8 quantization.
     # In order to match numerics between eager and compile, we upcast manually here.
     tensor_scaled = tensor.to(torch.float32) * scale
     max_value = torch.finfo(float8_dtype).max
@@ -1361,13 +1359,16 @@ def dequantize_affine_float8(
     output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
     """
-    Dequantizes the float8 tensor to float32 tensor.
+    Dequantizes the float8 tensor to high precision tensor.
     Args:
         tensor (torch.Tensor): Input float8 tensor to be dequantized.
         scale (torch.Tensor): Scaling factor for the dequantization.
         output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32).
     """
+    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
+    # upcasted to `float32` to divide by the scale, since scale is a fp32 for float8 quantization.
+    # In order to match numerics between eager and compile, we upcast manually here.
     fp8_tensor = tensor.to(torch.float32)
     hp_tensor = fp8_tensor / scale
     return hp_tensor.to(output_dtype)
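For context on the comment this commit extends: the scale is an fp32 tensor, so under `torch.compile` the float8 operand is upcast to `float32` automatically before the multiply/divide, and the eager implementation performs the same upcast explicitly so the two paths stay bit-identical. Below is a minimal sketch of the quantize/dequantize pair plus a round trip, following the scale convention visible in the diff (quantize multiplies by `scale`, dequantize divides by it). The `quantize_fp8_sketch`/`dequantize_fp8_sketch` names, the clamp step (the hunk is truncated after `max_value`), and the amax-based scale are illustrative assumptions, not the torchao API.

import torch


def quantize_fp8_sketch(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
    # Upcast manually so eager numerics match torch.compile, which
    # upcasts automatically when multiplying by the fp32 scale.
    tensor_scaled = tensor.to(torch.float32) * scale
    max_value = torch.finfo(float8_dtype).max
    # Clamp into the representable fp8 range before casting down
    # (assumed step; the diff is truncated after `max_value`).
    return tensor_scaled.clamp(min=-max_value, max=max_value).to(float8_dtype)


def dequantize_fp8_sketch(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    # Same manual upcast on the dequantize path, for the same
    # eager/compile parity reason.
    fp8_tensor = tensor.to(torch.float32)
    return (fp8_tensor / scale).to(output_dtype)


x = torch.randn(4, 8)
# Map the tensor's max magnitude onto the fp8 max, matching the
# multiply-then-divide scale convention above.
scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
x_fp8 = quantize_fp8_sketch(x, scale)
x_rt = dequantize_fp8_sketch(x_fp8, scale)
print((x - x_rt).abs().max())  # small round-trip error

# The manual upcast is what lets eager and compiled paths agree bit-for-bit.
compiled = torch.compile(quantize_fp8_sketch)
assert torch.equal(
    quantize_fp8_sketch(x, scale).to(torch.float32),
    compiled(x, scale).to(torch.float32),
)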
