replace to_affine_quantized_floatx with to_affine_quantized_float8 in quantization APIs

ghstack-source-id: f655d60cc7481b5c8db708318b5d6da720a7a0ea
ghstack-comment-id: 2608105249
Pull Request resolved: #1599
danielvegamyhre committed Jan 23, 2025
1 parent fa20ed1 commit c9b493f
Showing 7 changed files with 105 additions and 131 deletions.
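
For callers of the quantization APIs, the practical effect is a rename plus a slightly narrower signature: float8 weights are now constructed through the dedicated to_affine_quantized_float8 entry point rather than the generic to_affine_quantized_floatx, and the explicit scale_dtype keyword disappears from the call sites (see the autoquant diffs below). A minimal before/after sketch, with the shape, dtype, and per-tensor block size chosen purely for illustration:

import torch
from torchao.dtypes import to_affine_quantized_float8
from torchao.dtypes.floatx import Float8Layout

weight = torch.randn(64, 128, dtype=torch.bfloat16)   # illustrative weight
block_size = tuple(weight.shape)                       # one block -> per-tensor scale

# before this commit (generic floatx entry point, explicit scale dtype):
#   from torchao.dtypes import to_affine_quantized_floatx
#   qweight = to_affine_quantized_floatx(
#       input_float=weight,
#       block_size=block_size,
#       target_dtype=torch.float8_e4m3fn,
#       _layout=Float8Layout(),
#       scale_dtype=torch.float32,
#   )

# after this commit (dedicated float8 entry point, no scale_dtype):
qweight = to_affine_quantized_float8(
    input_float=weight,
    block_size=block_size,
    target_dtype=torch.float8_e4m3fn,  # illustrative fp8 dtype from FP8_TYPES
    _layout=Float8Layout(),
)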
2 changes: 1 addition & 1 deletion docs/source/api_ref_dtypes.rst
@@ -13,7 +13,7 @@ torchao.dtypes
to_nf4
to_affine_quantized_intx
to_affine_quantized_intx_static
to_affine_quantized_floatx
to_affine_quantized_float8
to_affine_quantized_floatx_static
to_affine_quantized_fpx
NF4Tensor
13 changes: 4 additions & 9 deletions torchao/dtypes/__init__.py
@@ -1,16 +1,14 @@
from . import affine_quantized_tensor_ops
from .affine_quantized_tensor import (
AffineQuantizedTensor,
to_affine_quantized_floatx,
to_affine_quantized_floatx_static,
# experimental, will be merged into floatx in the future
to_affine_quantized_fpx,
to_affine_quantized_intx,
to_affine_quantized_intx_static,
)
from .floatx import (
Float8Layout,
)
from .float8 import to_affine_quantized_float8
from .floatx import Float8Layout
from .nf4tensor import NF4Tensor, to_nf4
from .uintx import (
BlockSparseLayout,
@@ -24,10 +22,7 @@
UintxLayout,
to_marlinqqq_quantized_intx,
)
from .utils import (
Layout,
PlainLayout,
)
from .utils import Layout, PlainLayout

__all__ = [
"NF4Tensor",
@@ -36,8 +31,8 @@
"to_affine_quantized_intx",
"to_affine_quantized_intx_static",
"to_affine_quantized_fpx",
"to_affine_quantized_floatx",
"to_affine_quantized_floatx_static",
"to_affine_quantized_float8",
"to_marlinqqq_quantized_intx",
"Layout",
"PlainLayout",
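
With the export list updated, downstream imports change accordingly; the generic name is dropped from torchao.dtypes.__all__ while the float8 constructor takes its place. A sketch of the import-level migration only:

# old import -- to_affine_quantized_floatx is no longer exported
# from torchao.dtypes import to_affine_quantized_floatx

# new import
from torchao.dtypes import to_affine_quantized_float8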
92 changes: 22 additions & 70 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -10,13 +10,10 @@
MappingType,
ZeroPointDomain,
choose_qparams_affine,
choose_qparams_affine_float8,
choose_qparams_affine_floatx,
choose_qparams_and_quantize_affine_hqq,
dequantize_affine,
dequantize_affine_floatx,
quantize_affine,
quantize_affine_float8,
quantize_affine_floatx,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor
@@ -28,7 +25,6 @@
"AffineQuantizedTensor",
"register_layout",
"to_affine_quantized_intx",
"to_affine_quantized_floatx",
"to_affine_quantized_intx_static",
"to_affine_quantized_floatx_static",
"to_affine_quantized_fpx",
@@ -121,40 +117,28 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
if output_dtype is None:
output_dtype = self.dtype

from torchao.dtypes.floatx import FloatxTensorCoreLayout

if isinstance(self._layout, FloatxTensorCoreLayout):
int_data, scale = self.tensor_impl.get_plain()
return dequantize_affine_floatx(
int_data,
scale,
self._layout.ebits,
self._layout.mbits,
output_dtype=output_dtype,
)
else:
data, scale, zero_point = self.tensor_impl.get_plain()
dq = dequantize_affine(
data,
self.block_size,
scale,
zero_point,
data.dtype,
self.quant_min,
self.quant_max,
self.zero_point_domain,
output_dtype=output_dtype,
)
from torchao.dtypes.uintx import TensorCoreTiledLayout
data, scale, zero_point = self.tensor_impl.get_plain()
dq = dequantize_affine(
data,
self.block_size,
scale,
zero_point,
data.dtype,
self.quant_min,
self.quant_max,
self.zero_point_domain,
output_dtype=output_dtype,
)
from torchao.dtypes.uintx import TensorCoreTiledLayout

if isinstance(self._layout, TensorCoreTiledLayout):
# need to return to original shape if tensor was padded
# in preprocessing
# TODO: we could add an API for this if there are more use cases
# (e.g. dequant_post_process) in TensorImpl or Layout
for dim, dim_size in enumerate(self.shape):
dq = dq.narrow(dim, 0, dim_size)
return dq
if isinstance(self._layout, TensorCoreTiledLayout):
# need to return to original shape if tensor was padded
# in preprocessing
# TODO: we could add an API for this if there are more use cases
# (e.g. dequant_post_process) in TensorImpl or Layout
for dim, dim_size in enumerate(self.shape):
dq = dq.narrow(dim, 0, dim_size)
return dq

def __tensor_flatten__(self):
return ["tensor_impl"], [
@@ -272,7 +256,7 @@ def from_hp_to_intx(
# Note: output will be uint8 tensor for sub byte tensors for now

data = _layout.post_process(data)
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl_ctr = cls.get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
return cls(
tensor_impl,
@@ -417,36 +401,6 @@ def from_hp_to_fpx(
tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)

@classmethod
def from_hp_to_float8(
cls,
input_float: torch.Tensor,
target_dtype: torch.dtype,
block_size: Tuple[int, ...],
_layout: Layout = PlainLayout(),
):
assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8"
original_shape = input_float.shape
scale = choose_qparams_affine_float8(
input_float,
target_dtype,
target_dtype,
)
fp8_data = quantize_affine_float8(
input_float,
scale,
target_dtype,
)
fp8_data = _layout.post_process(fp8_data)
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(fp8_data, scale, None, _layout)
return cls(
tensor_impl,
block_size,
original_shape,
dtype=input_float.dtype,
)

@property
def _layout(self) -> Layout:
return self.tensor_impl._layout
@@ -500,9 +454,7 @@ def _apply_fn_to_data(self, fn):

to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx
to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
to_affine_quantized_float8 = AffineQuantizedTensor.from_hp_to_float8
# experimental will be merged in to floatx
to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx

54 changes: 54 additions & 0 deletions torchao/dtypes/floatx/float8_layout.py
@@ -18,6 +18,12 @@
addmm_float8_unwrapped_inference,
preprocess_data,
)
from torchao.quantization.quant_primitives import (
FP8_TYPES,
choose_qparams_affine_float8,
dequantize_affine_float8,
quantize_affine_float8,
)
from torchao.utils import _is_float8_type, fill_defaults

aten = torch.ops.aten
@@ -209,6 +215,51 @@ def __repr__(self):
)


class Float8Tensor(AffineQuantizedTensor):
"""
Float8 quantized tensor subclass which inherits AffineQuantizedTensor class.
"""

def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
if output_dtype is None:
output_dtype = self.dtype
int_data, scale = self.tensor_impl.get_plain()
return dequantize_affine_float8(
int_data,
scale,
output_dtype=output_dtype,
)

@classmethod
def from_hp_to_float8(
cls,
input_float: torch.Tensor,
target_dtype: torch.dtype,
block_size: Tuple[int, ...],
_layout: Layout = Float8Layout(),
):
assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8"
original_shape = input_float.shape
scale = choose_qparams_affine_float8(
input_float,
target_dtype,
)
fp8_data = quantize_affine_float8(
input_float,
scale,
target_dtype,
)
fp8_data = _layout.post_process(fp8_data)
tensor_impl_ctr = cls.get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(fp8_data, scale, None, _layout)
return cls(
tensor_impl,
block_size,
original_shape,
dtype=input_float.dtype,
)


##########################
# Float8 Dispatch Kernels
##########################
@@ -311,3 +362,6 @@ def _linear_fp_act_fp8_weight_impl(
bias: Optional[torch.Tensor],
):
return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)


to_affine_quantized_float8 = Float8Tensor.from_hp_to_float8
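
Putting the new pieces together, the Float8Tensor path can be exercised roughly as below. This is a minimal sketch, assuming a PyTorch build that provides the float8 dtypes; the shapes are illustrative, and per-tensor scaling is expressed by letting the single block span the whole tensor:

import torch
from torchao.dtypes import to_affine_quantized_float8

x = torch.randn(128, 256, dtype=torch.bfloat16)
qx = to_affine_quantized_float8(
    input_float=x,
    block_size=tuple(x.shape),         # single block -> per-tensor scale
    target_dtype=torch.float8_e4m3fn,
)

# Float8Tensor.dequantize() dispatches to dequantize_affine_float8 directly,
# instead of the integer affine path in AffineQuantizedTensor.dequantize(),
# and returns a tensor in the original high-precision dtype.
x_dq = qx.dequantize()
assert x_dq.dtype == x.dtype
print((x - x_dq).abs().max())          # expected to be small but nonzero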
20 changes: 6 additions & 14 deletions torchao/prototype/quantization/autoquant_v2.py
@@ -27,14 +27,8 @@
from torchao.quantization.autoquant import (
AutoQuantizableLinearWeight as AutoQuantizableLinearWeightV1,
)
from torchao.quantization.granularity import (
PerRow,
PerTensor,
)
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
from torchao.quantization.subclass import ( # noqa
Int8DynamicallyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
@@ -991,7 +985,7 @@ class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(
@classmethod
def from_float(cls, weight):
# avoid circular dep
from torchao.dtypes import to_affine_quantized_floatx
from torchao.dtypes import to_affine_quantized_float8
from torchao.quantization.quant_api import _input_activation_quant_func_fp8

# weight settings
@@ -1015,12 +1009,11 @@ def get_per_token_block_size(x):
activation_dtype=input_target_dtype,
)
block_size = get_weight_block_size(weight)
weight = to_affine_quantized_floatx(
weight = to_affine_quantized_float8(
input_float=weight,
block_size=block_size,
target_dtype=target_dtype,
_layout=_layout,
scale_dtype=torch.float32,
)
weight = super(
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, cls
@@ -1040,7 +1033,7 @@ class AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight(
@classmethod
def from_float(cls, weight):
# avoid circular dep
from torchao.dtypes import to_affine_quantized_floatx
from torchao.dtypes import to_affine_quantized_float8
from torchao.quantization.quant_api import _input_activation_quant_func_fp8

# weight settings
@@ -1058,12 +1051,11 @@ def get_weight_block_size(x):
activation_dtype=input_target_dtype,
)
block_size = get_weight_block_size(weight)
weight = to_affine_quantized_floatx(
weight = to_affine_quantized_float8(
input_float=weight,
block_size=block_size,
target_dtype=target_dtype,
_layout=_layout,
scale_dtype=torch.float32,
)
weight = super(
AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, cls
20 changes: 6 additions & 14 deletions torchao/quantization/autoquant.py
@@ -18,10 +18,7 @@
LinearActivationQuantizedTensor,
to_linear_activation_quantized,
)
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
from torchao.quantization.utils import (
compute_error,
quantize_activation_per_token_absmax,
@@ -34,10 +31,7 @@
is_sm_at_least_90,
)

from .granularity import (
PerRow,
PerTensor,
)
from .granularity import PerRow, PerTensor
from .subclass import ( # noqa
Int8DynamicallyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
@@ -969,7 +963,7 @@ class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(AQMixin, BFloat16Ten
@classmethod
def from_float(cls, weight):
# avoid circular dep
from torchao.dtypes import to_affine_quantized_floatx
from torchao.dtypes import to_affine_quantized_float8
from torchao.quantization.quant_api import _input_activation_quant_func_fp8

# weight settings
@@ -995,12 +989,11 @@ def get_per_token_block_size(x):
}
block_size = get_weight_block_size(weight)

weight = to_affine_quantized_floatx(
weight = to_affine_quantized_float8(
input_float=weight,
block_size=block_size,
target_dtype=target_dtype,
_layout=_layout,
scale_dtype=torch.float32,
)
weight = to_linear_activation_quantized(
weight, input_quant_func, quant_kwargs=input_quant_kwargs
Expand All @@ -1025,7 +1018,7 @@ class AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight(
@classmethod
def from_float(cls, weight):
# avoid circular dep
from torchao.dtypes import to_affine_quantized_floatx
from torchao.dtypes import to_affine_quantized_float8
from torchao.quantization.quant_api import _input_activation_quant_func_fp8

# weight settings
@@ -1043,12 +1036,11 @@ def get_weight_block_size(x):
"activation_dtype": input_target_dtype,
}
block_size = get_weight_block_size(weight)
weight = to_affine_quantized_floatx(
weight = to_affine_quantized_float8(
input_float=weight,
block_size=block_size,
target_dtype=target_dtype,
_layout=_layout,
scale_dtype=torch.float32,
)
weight = super(
AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, cls
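
Both autoquant modules make the same two call-site adjustments: the constructor is renamed, and the scale_dtype=torch.float32 keyword is dropped because the float8 path now derives its scale inside choose_qparams_affine_float8 (which, in this commit, is called with only the input tensor and target dtype). For orientation, a hedged sketch of how the PerTensor and PerRow granularities used by these subclasses are commonly mapped onto block_size; the (1, in_features) convention for per-row scaling is an assumption, since the get_weight_block_size helpers are not shown in this diff:

import torch
from torchao.dtypes import to_affine_quantized_float8
from torchao.quantization.granularity import PerRow, PerTensor

def weight_block_size(weight: torch.Tensor, granularity):
    # Assumed convention: PerTensor -> one block spanning the whole weight,
    # PerRow -> one block per output row, i.e. (1, in_features).
    if isinstance(granularity, PerTensor):
        return tuple(weight.shape)
    if isinstance(granularity, PerRow):
        return (1, weight.shape[1])
    raise ValueError(f"unsupported granularity: {granularity}")

weight = torch.randn(32, 64, dtype=torch.bfloat16)    # illustrative
qweight = to_affine_quantized_float8(
    input_float=weight,
    block_size=weight_block_size(weight, PerTensor()),
    target_dtype=torch.float8_e4m3fn,
)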