config migration: fpx, gemlite, uintx
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: d8152a4e6d0c254652a2d5dac4fcb82d04bb5630
ghstack-comment-id: 2649778077
Pull Request resolved: #1697
vkuzo committed Feb 11, 2025
1 parent ebb7fe3 commit 75012dc
Showing 5 changed files with 156 additions and 76 deletions.
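
At the call sites exercised by the test changes below, the fpx, gemlite, and uintx workflows switch from callables applied directly to a module to configs consumed by quantize_. A minimal before/after sketch of the uintx case (assumes a CUDA device; names imported from quant_api.py as changed in this commit):

import torch
from torchao.quantization.quant_api import quantize_, uintx_weight_only

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")

# old style, removed by this commit: the workflow function returned a module transform
# uintx_weight_only(torch.uint4)(linear)

# new style: the same name now builds a config, and quantize_ applies it in place
quantize_(linear, uintx_weight_only(torch.uint4))
linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))
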
6 changes: 3 additions & 3 deletions test/dtypes/test_uintx.py
@@ -150,7 +150,7 @@ def test_uintx_target_dtype(dtype):

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
# make sure it runs
uintx_weight_only(dtype)(linear)
quantize_(linear, uintx_weight_only(dtype))
linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))


@@ -165,7 +165,7 @@ def test_uintx_target_dtype_compile(dtype):

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
# make sure it runs
uintx_weight_only(dtype)(linear)
quantize_(linear, uintx_weight_only(dtype))
linear = torch.compile(linear)
linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))

@@ -196,6 +196,6 @@ def test_uintx_model_size(dtype):
)
bf16_size = get_model_size_in_bytes(linear)
# make sure it runs
uintx_weight_only(dtype)(linear[0])
quantize_(linear[0], uintx_weight_only(dtype))
quantized_size = get_model_size_in_bytes(linear)
assert bf16_size * _dtype_to_ratio[dtype] == quantized_size
8 changes: 3 additions & 5 deletions test/hqq/test_hqq_affine.py
@@ -53,12 +53,10 @@ def _eval_hqq(dtype):
dummy_linear.weight.data = W
if dtype == torch.uint4:
config = int4_weight_only(group_size=max(block_size), use_hqq=True)
quantize_(dummy_linear, config)
q_tensor_hqq = dummy_linear.weight
else:
q_tensor_hqq = uintx_weight_only(
dtype, group_size=max(block_size), use_hqq=True
)(dummy_linear).weight
config = uintx_weight_only(dtype, group_size=max(block_size), use_hqq=True)
quantize_(dummy_linear, config)
q_tensor_hqq = dummy_linear.weight

quant_linear_layer = torch.nn.Linear(
W.shape[1], W.shape[0], bias=False, device=W.device
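
Since the old function names are kept as aliases of the new config classes (see the quant_api.py changes below), call sites like the one above keep compiling, but what uintx_weight_only(...) returns is now a config object rather than a closure. A small sketch of that behavior, mirroring the hqq branch of this test (assumes a CUDA device; class name as defined in quant_api.py):

import torch
from torchao.quantization.quant_api import (
    UIntXWeightOnlyConfig,
    quantize_,
    uintx_weight_only,
)

dummy_linear = torch.nn.Linear(128, 128, bias=False, dtype=torch.bfloat16, device="cuda")

# the BC alias builds a dataclass config; it no longer transforms the module itself
config = uintx_weight_only(torch.uint4, group_size=64, use_hqq=True)
assert isinstance(config, UIntXWeightOnlyConfig)

quantize_(dummy_linear, config)  # dispatch happens on the config type
q_tensor_hqq = dummy_linear.weight
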
23 changes: 21 additions & 2 deletions test/quantization/test_quant_api.py
@@ -33,11 +33,14 @@
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
fpx_weight_only,
gemlite_uintx_weight_only,
int4_dynamic_activation_int4_weight,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
uintx_weight_only,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.subclass import (
@@ -55,6 +58,13 @@
unwrap_tensor_subclass,
)

try:
import gemlite # noqa: F401

has_gemlite = True
except ModuleNotFoundError:
has_gemlite = False


def dynamic_quant(model, example_inputs):
m = torch.export.export(model, example_inputs, strict=True).module()
@@ -804,6 +814,9 @@ def test_int4wo_cpu(self, dtype, x_dim):
int8_dynamic_activation_int8_weight(),
int8_dynamic_activation_int4_weight(),
int8_weight_only(),
fpx_weight_only(ebits=4, mbits=3),
gemlite_uintx_weight_only(),
uintx_weight_only(dtype=torch.uint4),
],
)
def test_workflow_e2e_numerics(self, config):
@@ -827,17 +840,23 @@ def test_workflow_e2e_numerics(self, config):
and is_sm_at_least_90()
):
return unittest.skip("only supported on CUDA capability 8.9, not greater")
elif isinstance(config, gemlite_uintx_weight_only) and not has_gemlite:
return unittest.skip("gemlite not available")

# scale has to be moved to cuda here because the parametrization init
# code happens before gating for cuda availability
if isinstance(config, float8_static_activation_float8_weight):
config.scale = config.scale.to("cuda")

dtype = torch.bfloat16
if isinstance(config, gemlite_uintx_weight_only):
dtype = torch.float16

# set up inputs
x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
x = torch.randn(128, 128, device="cuda", dtype=dtype)
# TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
# is that expected?
m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().to(dtype)
m_q = copy.deepcopy(m_ref)

# quantize
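
The dtype handling added above reflects a constraint noted in the gemlite docstring further down: the gemlite workflow only works for fp16 models, so the e2e test switches both the model and its inputs to float16 for that config. A condensed sketch of that path (assumes gemlite is installed and a CUDA device is available):

import copy

import torch
from torchao.quantization.quant_api import gemlite_uintx_weight_only, quantize_

m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().to(torch.float16)
m_q = copy.deepcopy(m_ref)
quantize_(m_q, gemlite_uintx_weight_only())

x = torch.randn(128, 128, device="cuda", dtype=torch.float16)
with torch.no_grad():
    y_ref = m_ref(x)
    y_q = m_q(x)  # the test then checks that y_q stays close to y_ref
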
6 changes: 6 additions & 0 deletions torchao/quantization/__init__.py
@@ -49,11 +49,14 @@
Float8DynamicActivationFloat8WeightConfig,
Float8StaticActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
FPXWeightOnlyConfig,
GemliteUIntXWeightOnlyConfig,
Int4DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
UIntxWeightOnlyConfig,
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
@@ -135,6 +138,9 @@
"Float8WeightOnlyConfig",
"Float8DynamicActivationFloat8WeightConfig",
"Float8StaticActivationFloat8WeightConfig",
"UIntxWeightOnlyConfig",
"FPXWeightOnlyConfig",
"GemliteUIntXWeightOnlyConfig",
# smooth quant - subject to change
"get_scale",
"SmoothFakeDynQuantMixin",
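
With the new names exported here, code can construct the config classes directly instead of going through the BC function aliases; the e2e test above relies on this for isinstance-based gating. A short sketch using the fpx config (constructor arguments as declared in quant_api.py):

from torchao.quantization import FPXWeightOnlyConfig

config = FPXWeightOnlyConfig(ebits=4, mbits=3)
# equivalent to fpx_weight_only(ebits=4, mbits=3), which is now an alias of this class
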
189 changes: 123 additions & 66 deletions torchao/quantization/quant_api.py
@@ -729,12 +729,8 @@ def _int4_dynamic_activation_int4_weight_transform(
return module


def gemlite_uintx_weight_only(
group_size: Optional[int] = 64,
bit_width: int = 4,
packing_bitwidth: int = 32,
contiguous: Optional[bool] = None,
):
@dataclass
class GemliteUIntXWeightOnlyConfig(AOBaseConfig):
"""
applies weight only 4 or 8 bit integer quantization and utilizes the gemlite triton kernel and its associated weight packing format.
This only works for fp16 models. 8 bit quantization is symmetric, 4 bit quantization is asymmetric.
@@ -747,16 +743,39 @@ def gemlite_uintx_weight_only(
`contiguous`: if set, the weight will be packed as specified. Leaving it as None lets gemlite determine the best choice.
"""

group_size: Optional[int] = 64
bit_width: int = 4
packing_bitwidth: int = 32
contiguous: Optional[bool] = None


# for BC
gemlite_uintx_weight_only = GemliteUIntXWeightOnlyConfig


@register_quantize_module_handler(GemliteUIntXWeightOnlyConfig)
def _gemlite_uintx_weight_only_transform(
module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig
):
group_size = config.group_size
bit_width = config.bit_width
packing_bitwidth = config.packing_bitwidth
contiguous = config.contiguous

weight = module.weight

from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs

use_hqq = True if bit_width == 4 else False
apply_fn = lambda weight: to_affine_quantized_intx(
new_weight = to_affine_quantized_intx(
weight,
**get_gemlite_aqt_kwargs(
weight, group_size, bit_width, packing_bitwidth, contiguous, use_hqq
),
)
return _get_linear_subclass_inserter(apply_fn)
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module


@dataclass
@@ -1380,9 +1399,10 @@ def _float8_static_activation_float8_weight_transform(
return module


def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
@dataclass
class UIntXWeightOnlyConfig(AOBaseConfig):
"""
Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
Configuration for applying uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
x is the number of bits specified by `dtype`
Args:
@@ -1392,6 +1412,28 @@ def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
`pack_dim`: the dimension we use for packing, defaults to -1
`use_hqq`: whether to use hqq algorithm or the default algorithm to quantize the weight
"""

dtype: torch.dtype
group_size: int = 64
pack_dim: int = -1
use_hqq: bool = False


# for BC
uintx_weight_only = UIntXWeightOnlyConfig


@register_quantize_module_handler(UIntXWeightOnlyConfig)
def _uintx_weight_only_transform(
module: torch.nn.Module, config: UIntXWeightOnlyConfig
):
dtype = config.dtype
group_size = config.group_size
pack_dim = config.pack_dim
use_hqq = config.use_hqq

weight = module.weight

from torchao.quantization.quant_primitives import _DTYPE_TO_QVALUE_BOUNDS

SUPPORTED_DTYPES = {
@@ -1406,49 +1448,50 @@
}
assert dtype in SUPPORTED_DTYPES, f"Unsupported dtype for hqq: {dtype}"

def apply_uintx_weight_only_quant(weight, dtype):
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size)

if use_hqq:
if dtype == torch.uint4:
logger.warn(
"Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance"
)
quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype]
dtype = torch.uint8
eps = None
zero_point_dtype = None
zero_point_domain = ZeroPointDomain.FLOAT
preserve_zero = False
_layout = PlainLayout()
else:
quant_min, quant_max = None, None
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int32
zero_point_domain = ZeroPointDomain.INT
preserve_zero = True
_layout = UintxLayout(dtype=dtype, pack_dim=pack_dim)
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size)

return to_affine_quantized_intx(
weight,
mapping_type,
block_size,
dtype,
quant_min=quant_min,
quant_max=quant_max,
eps=eps,
zero_point_dtype=zero_point_dtype,
zero_point_domain=zero_point_domain,
preserve_zero=preserve_zero,
_layout=_layout,
use_hqq=use_hqq,
)
if use_hqq:
if dtype == torch.uint4:
logger.warn(
"Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance"
)
quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype]
dtype = torch.uint8
eps = None
zero_point_dtype = None
zero_point_domain = ZeroPointDomain.FLOAT
preserve_zero = False
_layout = PlainLayout()
else:
quant_min, quant_max = None, None
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int32
zero_point_domain = ZeroPointDomain.INT
preserve_zero = True
_layout = UintxLayout(dtype=dtype, pack_dim=pack_dim)

return _get_linear_subclass_inserter(apply_uintx_weight_only_quant, dtype=dtype)
new_weight = to_affine_quantized_intx(
weight,
mapping_type,
block_size,
dtype,
quant_min=quant_min,
quant_max=quant_max,
eps=eps,
zero_point_dtype=zero_point_dtype,
zero_point_domain=zero_point_domain,
preserve_zero=preserve_zero,
_layout=_layout,
use_hqq=use_hqq,
)
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module


def fpx_weight_only(ebits: int, mbits: int):
@dataclass
class FPXWeightOnlyConfig(AOBaseConfig):
"""Sub-byte floating point dtypes defined by `ebits`: exponent bits and `mbits`: mantissa bits
e.g. fp6_e3_m2, fp6_e2_m3, ...
The packing format and kernels are from the fp6-llm paper: https://arxiv.org/abs/2401.14112
@@ -1459,26 +1502,40 @@ def fpx_weight_only(ebits: int, mbits: int):
in the future
"""

def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor:
from torchao.dtypes import to_affine_quantized_fpx
from torchao.dtypes.floatx import FloatxTensorCoreLayout
ebits: int
mbits: int

assert (
weight.dim() == 2
), f"floatx only works for 2-d Tensor, got: {weight.dim()}"
out_dim, in_dim = weight.shape
if (in_dim % 64 != 0) or (out_dim % 256 != 0):
logger.info(
f"Skipping floatx quantization float{ebits + mbits + 1}_{ebits}_{mbits} because "
f"the shape is not compatible with the kernel: in_dim={in_dim}, out_dim={out_dim} "
"expected in_dim % 64 == 0 and out_dim % 256 == 0"
)
return weight

_layout = FloatxTensorCoreLayout(ebits, mbits)
return to_affine_quantized_fpx(weight, _layout)
# for BC
fpx_weight_only = FPXWeightOnlyConfig


@register_quantize_module_handler(FPXWeightOnlyConfig)
def _fpx_weight_only_transform(
module: torch.nn.Module, config: FPXWeightOnlyConfig
) -> torch.nn.Module:
ebits = config.ebits
mbits = config.mbits
weight = module.weight

from torchao.dtypes import to_affine_quantized_fpx
from torchao.dtypes.floatx import FloatxTensorCoreLayout

return _get_linear_subclass_inserter(apply_quant_llm)
assert weight.dim() == 2, f"floatx only works for 2-d Tensor, got: {weight.dim()}"
out_dim, in_dim = weight.shape
if (in_dim % 64 != 0) or (out_dim % 256 != 0):
logger.info(
f"Skipping floatx quantization float{ebits + mbits + 1}_{ebits}_{mbits} because "
f"the shape is not compatible with the kernel: in_dim={in_dim}, out_dim={out_dim} "
"expected in_dim % 64 == 0 and out_dim % 256 == 0"
)
return module

_layout = FloatxTensorCoreLayout(ebits, mbits)
new_weight = to_affine_quantized_fpx(weight, _layout)
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module


if TORCH_VERSION_AT_LEAST_2_5:
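
All three migrations in this file follow the same recipe: the workflow function becomes an AOBaseConfig dataclass, the old name is kept as a BC alias, and the weight-swapping logic moves into a transform registered via register_quantize_module_handler so that quantize_ can dispatch on the config type. Below is a minimal sketch of that recipe for a hypothetical new workflow; ExampleWeightOnlyConfig, its handler, and the import paths for AOBaseConfig and register_quantize_module_handler are illustrative assumptions, not part of this commit:

from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig  # assumed location of the base config class
from torchao.quantization.quant_api import quantize_
from torchao.quantization.transform_module import (  # assumed location of the handler registry
    register_quantize_module_handler,
)


def _fake_quantize(weight: torch.Tensor, group_size: int) -> torch.Tensor:
    # stand-in for real quantization (the transforms above call to_affine_quantized_intx / _fpx)
    return weight.clone()


@dataclass
class ExampleWeightOnlyConfig(AOBaseConfig):
    """Hypothetical config: hyperparameters only, no transform logic."""

    group_size: int = 64


# for BC, the old function name would simply alias the config class
example_weight_only = ExampleWeightOnlyConfig


@register_quantize_module_handler(ExampleWeightOnlyConfig)
def _example_weight_only_transform(
    module: torch.nn.Module, config: ExampleWeightOnlyConfig
) -> torch.nn.Module:
    # replace module.weight with the (here fake-)quantized tensor, mirroring the transforms above
    new_weight = _fake_quantize(module.weight, config.group_size)
    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
    return module


# usage: quantize_ sees an ExampleWeightOnlyConfig and dispatches to the registered transform
linear = torch.nn.Linear(16, 16)
quantize_(linear, example_weight_only())
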
