Commit
use bitshifting for power of 2 rounding
danielvegamyhre committed Feb 6, 2025
1 parent 34cc033 commit ab93e18
Showing 2 changed files with 7 additions and 3 deletions.
1 change: 1 addition & 0 deletions torchao/float8/float8_scaling_utils.py
@@ -27,6 +27,7 @@
 )


+# TODO(danielvegamyhre): refactor to accept Float8LinearConfig directly
 def hp_tensor_to_float8_dynamic(
     hp_tensor: torch.Tensor,
     float8_dtype: torch.dtype,
9 changes: 6 additions & 3 deletions torchao/float8/float8_utils.py
@@ -45,12 +45,15 @@ def amax_to_scale(
     amax = amax.to(torch.float64)
     if float8_dtype in FP8_TYPES:
         res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
+        res = res.to(torch.float32)
     else:
         raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
     if round_scales_to_power_of_2:
-        # rounds down to the nearest power of 2.
-        res = torch.exp2(torch.floor(torch.log2(res)))
-    return res.to(torch.float32)
+        # rounds down to the nearest power of 2
+        res = res.view(torch.int32)
+        res = (res >> 23) << 23
+        res = res.view(torch.float32)
+    return res


 @torch.no_grad()
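For context, here is a minimal sketch (not part of this commit) of why the bit shift rounds down to the nearest power of 2: an IEEE-754 float32 stores 1 sign bit, 8 exponent bits, and 23 mantissa bits, so viewing the value as int32 and shifting right then left by 23 clears the mantissa and keeps only the sign and exponent, i.e. the largest power of 2 not exceeding the input (for positive, finite, normal values such as float8 scales). This is also why res is now cast to float32 before the rounding branch: the trick assumes a float32 bit layout. The helper name below is hypothetical.

import torch

def round_down_to_power_of_2(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the bit trick in this commit; assumes float32 input.
    assert x.dtype == torch.float32, "the bit trick assumes a float32 bit layout"
    bits = x.view(torch.int32)       # reinterpret the raw bits, no data copy
    bits = (bits >> 23) << 23        # zero out the 23-bit mantissa
    return bits.view(torch.float32)  # reinterpret back as float32

x = torch.tensor([0.3, 1.0, 5.0, 100.0], dtype=torch.float32)
# Matches the exp2/floor/log2 formulation this commit replaces (0.25, 1.0, 4.0, 64.0):
assert torch.equal(round_down_to_power_of_2(x), torch.exp2(torch.floor(torch.log2(x))))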
