Skip to content

Commit

Permalink
Add power-of-2 scale rounding option (`power_of_2_scale`) to amax_to_scale
Browse the repository at this point in the history
  • Loading branch information
danielvegamyhre committed Feb 5, 2025
1 parent a9fe17e commit 896bd8f
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
4 changes: 1 addition & 3 deletions torchao/float8/float8_scaling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,8 @@ def hp_tensor_to_float8_dynamic(
device_mesh,
scaling_granularity,
axiswise_dim,
power_of_2_scale,
)
if power_of_2_scale:
# rounds down to the nearest power of 2.
scale = torch.exp2(torch.floor(torch.log2(scale)))
return hp_tensor_and_scale_to_float8(
hp_tensor,
scale,
Expand Down
18 changes: 10 additions & 8 deletions torchao/float8/float8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@
import torch.distributed as dist
from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce

from torchao.float8.config import (
Float8LinearConfig,
ScalingGranularity,
ScalingType,
)
from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType

# Helpful visualizer for debugging (only supports fp32):
# https://www.h-schmidt.net/FloatConverter/IEEE754.html
Expand All @@ -33,11 +29,14 @@


@torch.no_grad()
def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype):
def amax_to_scale(
amax: torch.Tensor, float8_dtype: torch.dtype, power_of_2_scale: bool = False
):
"""Converts the amax value of a tensor to the fp8 scale.
Args:
amax: The amax value of the tensor.
float8_dtype: The float8 dtype.
power_of_2_scale: if true, round scaling factor down to the nearest power of 2.
"""
# torch.compile and eager show different numerics for 1.0 / float32,
# upcast to float64 to ensure same numeric between compile and eager
Expand All @@ -46,7 +45,9 @@ def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype):
res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
else:
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

if power_of_2_scale:
# rounds down to the nearest power of 2.
res = torch.exp2(torch.floor(torch.log2(res)))
return res.to(torch.float32)


Expand Down Expand Up @@ -125,6 +126,7 @@ def tensor_to_scale(
device_mesh=None,
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
axiswise_dim: Optional[int] = None,
power_of_2_scale: bool = False,
) -> torch.Tensor:
amax = tensor_to_amax(
x,
Expand All @@ -133,7 +135,7 @@ def tensor_to_scale(
scaling_granularity,
axiswise_dim,
)
return amax_to_scale(amax, float8_dtype)
return amax_to_scale(amax, float8_dtype, power_of_2_scale=power_of_2_scale)


def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
Expand Down

0 comments on commit 896bd8f

Please sign in to comment.