diff --git a/test/float8/test_utils.py b/test/float8/test_utils.py
new file mode 100644
index 0000000000..34b07d502e
--- /dev/null
+++ b/test/float8/test_utils.py
@@ -0,0 +1,35 @@
+import pytest
+import torch
+
+from torchao.float8.float8_utils import _round_down_to_power_of_2
+
+
+@pytest.mark.parametrize(
+    "input_shape",
+    [
+        (1,),
+        (2, 3),
+        (8, 2048, 4, 1024),
+    ],
+)
+@pytest.mark.parametrize(
+    "multiplier",
+    [
+        1.0,
+        2.5,
+        10.0,
+    ],
+)
+def test_round_down_to_power_of_2(input_shape: tuple[int, ...], multiplier: float):
+    input_tensor = torch.rand(*input_shape, dtype=torch.float32) * multiplier
+    expected_output = torch.exp2(torch.floor(torch.log2(input_tensor)))
+    result = _round_down_to_power_of_2(input_tensor)
+    assert torch.equal(
+        result, expected_output
+    ), f"expected {expected_output}, but got {result}"
+
+
+def test_non_float32_input():
+    non_float32_tensor = torch.tensor([3.0], dtype=torch.float64)
+    with pytest.raises(AssertionError, match="input must be float32 tensor"):
+        _round_down_to_power_of_2(non_float32_tensor)
diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py
index bdb08bbb01..a7002516b8 100644
--- a/torchao/float8/float8_utils.py
+++ b/torchao/float8/float8_utils.py
@@ -49,10 +49,7 @@ def amax_to_scale(
     else:
         raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
     if round_scales_to_power_of_2:
-        # rounds down to the nearest power of 2
-        res = res.view(torch.int32)
-        res = (res >> 23) << 23
-        res = res.view(torch.float32)
+        res = _round_down_to_power_of_2(res)
     return res
 
 
@@ -286,3 +283,12 @@ def config_has_stateful_scaling(config: Float8LinearConfig) -> bool:
         or config.cast_config_weight.scaling_type != ScalingType.DYNAMIC
         or config.cast_config_grad_output.scaling_type != ScalingType.DYNAMIC
     )
+
+
+def _round_down_to_power_of_2(x: torch.Tensor) -> torch.Tensor:
+    assert x.dtype == torch.float32, "input must be float32 tensor"
+    # rounds down to the nearest power of 2
+    x = x.view(torch.int32)
+    x = (x >> 23) << 23
+    x = x.view(torch.float32)
+    return x
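
For reference, the bit manipulation extracted into _round_down_to_power_of_2 relies on the IEEE-754 float32 layout: 1 sign bit, 8 exponent bits, 23 mantissa bits. Viewing the tensor as int32 and shifting right then left by 23 zeroes the mantissa while keeping the sign and exponent, leaving exactly 2^floor(log2(x)) for positive inputs. A minimal standalone sketch, not part of the patch (the example values and printed output are illustrative only):

import torch

x = torch.tensor([0.75, 1.0, 3.5, 100.0], dtype=torch.float32)
bits = x.view(torch.int32)                      # reinterpret raw bits, no copy
rounded = ((bits >> 23) << 23).view(torch.float32)
print(rounded)                                  # tensor([ 0.5000,  1.0000,  2.0000, 64.0000])
# agrees with the reference formula used by the new test
assert torch.equal(rounded, torch.exp2(torch.floor(torch.log2(x))))

Rounding a scale down to a power of 2 means multiplying or dividing by it only changes the exponent of the scaled value, never the mantissa, so no extra rounding error is introduced; this is presumably the motivation for the round_scales_to_power_of_2 option in amax_to_scale.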