Commit 5742998

Rectify Asym Compression/Decompression Pathways (#225)
* fix asym
* fix
* update tests
* fix
* update
* docstring, comments, typing
1 parent 914d4dd commit 5742998

File tree

4 files changed, +59 -10 lines changed

  src/compressed_tensors/compressors/model_compressors/model_compressor.py (+12 -3)
  src/compressed_tensors/compressors/quantized_compressors/base.py (+31 -2)
  tests/test_compressors/quantized_compressors/test_int_quant.py (+7 -3)
  tests/test_compressors/quantized_compressors/test_pack_quant.py (+9 -2)

src/compressed_tensors/compressors/model_compressors/model_compressor.py (+12 -3)

@@ -19,7 +19,7 @@
 import re
 from contextlib import contextmanager
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, TypeVar, Union

 import compressed_tensors
 import torch
@@ -522,10 +522,13 @@ def _replace_weights(self, dense_weight_generator, model: Module):
             update_parameter_data(module, data, param_name)


-def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
+def map_modules_to_quant_args(
+    model: Module,
+) -> Dict[str, Union[QuantizationArgs, Tuple[QuantizationArgs, QuantizationArgs]]]:
     """
     Given a pytorch model, map out the submodule name (usually linear layers)
-    to the QuantizationArgs
+    to the weight QuantizationArgs. If running input activation quantization, will also
+    map to the input QuantizationArgs in a tuple.

     :param model: pytorch model
     """
@@ -535,6 +538,12 @@ def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
             if submodule.quantization_scheme.weights is not None:
                 name = fix_fsdp_module_name(name)
                 quantized_modules_to_args[name] = submodule.quantization_scheme.weights
+            if submodule.quantization_scheme.input_activations is not None:
+                weight_args = quantized_modules_to_args.get(name)
+                quantized_modules_to_args[name] = (
+                    weight_args,
+                    submodule.quantization_scheme.input_activations,
+                )

     return quantized_modules_to_args
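
With this change the values in the returned mapping are no longer always a bare QuantizationArgs: modules that also quantize input activations map to a (weight_args, input_args) tuple. Below is a minimal sketch of how a caller can normalize the two forms; the compressed_tensors.quantization import path and the split_quant_args helper are illustrative, not part of the commit.

from typing import Optional, Tuple, Union

from compressed_tensors.quantization import QuantizationArgs

# either a bare weight args or a (weight_args, input_args) pair
Entry = Union[QuantizationArgs, Tuple[QuantizationArgs, QuantizationArgs]]


def split_quant_args(entry: Entry) -> Tuple[QuantizationArgs, Optional[QuantizationArgs]]:
    """Return (weight_args, input_args) for either form stored in the mapping."""
    if isinstance(entry, tuple):
        weight_args, input_args = entry
        return weight_args, input_args
    return entry, None


# weight-only module -> plain QuantizationArgs
weight_only = QuantizationArgs(num_bits=4, symmetric=False)
# weight + input activation quantization -> (weight_args, input_args) tuple
weight_and_input = (weight_only, QuantizationArgs(num_bits=8))

print(split_quant_args(weight_only))       # (weight_args, None)
print(split_quant_args(weight_and_input))  # (weight_args, input_args)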

src/compressed_tensors/compressors/quantized_compressors/base.py (+31 -2)

@@ -82,19 +82,44 @@ def compress(
         """
         compressed_dict = {}
         weight_suffix = ".weight"
+        input_zp_suffix = ".input_zero_point"
+        weight_zp_suffix = ".weight_zero_point"
         _LOGGER.debug(
             f"Compressing model with {len(model_state)} parameterized layers..."
         )

         for name, value in tqdm(model_state.items(), desc="Quantized Compression"):
+            # check if the parameter we're compressing is the weight zp
+            # or the input zp
+            is_weight_zp = name.endswith(weight_zp_suffix)
+            is_input_zp = name.endswith(input_zp_suffix)
+
+            # if we're saving the weight zp, fetch weight quant args
+            if is_weight_zp:
+                quant_args_zp = names_to_scheme.get(name[: -(len(weight_zp_suffix))])
+                if isinstance(quant_args_zp, tuple):
+                    # If tuple, first value is weight args, second is input args
+                    quant_args_zp = quant_args_zp[0]
+
+            # if we're saving the input zp, fetch input quant args
+            if is_input_zp:
+                input_args_zp = names_to_scheme.get(name[: -(len(input_zp_suffix))])
+                if isinstance(input_args_zp, tuple):
+                    # If tuple, first value is weight args, second is input args
+                    input_args_zp = input_args_zp[-1]
+
             if name.endswith(weight_suffix):
                 prefix = name[: -(len(weight_suffix))]
                 scale = model_state.get(merge_names(prefix, "weight_scale"), None)
                 zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
                 g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
                 if scale is not None:
                     # weight is quantized, compress it
-                    quant_args = names_to_scheme[prefix]
+                    if isinstance(names_to_scheme[prefix], tuple):
+                        quant_args = names_to_scheme[prefix][0]
+                    else:
+                        quant_args = names_to_scheme[prefix]
+
                     compressed_data = self.compress_weight(
                         weight=value,
                         scale=scale,
@@ -107,7 +132,11 @@ def compress(
                         compressed_dict[merge_names(prefix, key)] = value
                 else:
                     compressed_dict[name] = value.to("cpu")
-            elif name.endswith("zero_point") and torch.all(value == 0):
+            # only save if asym
+            elif is_weight_zp and quant_args_zp.symmetric:
+                continue
+            # only save if asym
+            elif is_input_zp and input_args_zp.symmetric:
                 continue
             elif name.endswith("g_idx") and torch.any(value <= -1):
                 continue
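
The net effect in compress is that zero-point tensors are now kept or dropped based on the symmetry of the matching quantization args (weight args for *.weight_zero_point, input args for *.input_zero_point) instead of the old all-zeros check. A standalone sketch of that decision, using a stand-in DummyArgs object rather than the real library types, looks like this:

from dataclasses import dataclass


@dataclass
class DummyArgs:
    symmetric: bool = True


def should_skip_zero_point(name: str, names_to_scheme: dict) -> bool:
    """True when a zero-point tensor can be dropped, i.e. the matching args are symmetric."""
    for suffix, index in ((".weight_zero_point", 0), (".input_zero_point", -1)):
        if name.endswith(suffix):
            args = names_to_scheme.get(name[: -len(suffix)])
            if isinstance(args, tuple):
                # (weight_args, input_args) when input activations are also quantized
                args = args[index]
            return args is not None and args.symmetric
    return False


schemes = {"layer": (DummyArgs(symmetric=True), DummyArgs(symmetric=False))}
print(should_skip_zero_point("layer.weight_zero_point", schemes))  # True: symmetric weights, zp dropped
print(should_skip_zero_point("layer.input_zero_point", schemes))   # False: asymmetric inputs, zp saved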

tests/test_compressors/quantized_compressors/test_int_quant.py (+7 -3)

@@ -27,11 +27,13 @@
 from safetensors.torch import save_file


-def get_dummy_quant_config(strategy, group_size=None):
+def get_dummy_quant_config(strategy, group_size=None, symmetric=True):
     config_groups = {
         "group_1": QuantizationScheme(
             targets=["Linear"],
-            weights=QuantizationArgs(strategy=strategy, group_size=group_size),
+            weights=QuantizationArgs(
+                strategy=strategy, group_size=group_size, symmetric=symmetric
+            ),
         ),
     }
     ignore = ["lm_head"]
@@ -69,7 +71,9 @@ def test_quant_format(strategy, symmetric, group_size, sc, zp):
         "dummy.weight_scale": torch.tensor(sc, dtype=torch.float32),
         "dummy.weight_zero_point": torch.tensor(zp, dtype=torch.int32),
     }
-    quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size)
+    quant_config = get_dummy_quant_config(
+        strategy=strategy, group_size=group_size, symmetric=symmetric
+    )

     compressor = IntQuantizationCompressor(config=quant_config)
     quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights}
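
The parametrized symmetric flag now flows through the dummy config helper, so asymmetric cases build weight args whose zero-points must survive compression. A quick illustration of what the helper now constructs (import path assumed from the library, values arbitrary):

from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(strategy="tensor", group_size=None, symmetric=False),
)
assert scheme.weights.symmetric is False  # asymmetric weights -> zero-point gets saved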

tests/test_compressors/quantized_compressors/test_pack_quant.py (+9 -2)

@@ -37,7 +37,9 @@
 from torch.nn.modules import Linear, Sequential


-def get_dummy_quant_config(num_bits=4, strategy=None, group_size=None, actorder=None):
+def get_dummy_quant_config(
+    num_bits=4, strategy=None, group_size=None, actorder=None, symmetric=True
+):
     config_groups = {
         "group_1": QuantizationScheme(
             targets=["Linear"],
@@ -46,6 +48,7 @@ def get_dummy_quant_config(num_bits=4, strategy=None, group_size=None, actorder=
                 strategy=strategy,
                 group_size=group_size,
                 actorder=actorder,
+                symmetric=symmetric,
             ),
         ),
     }
@@ -151,21 +154,25 @@ def test_reload_match(tmp_path, num_bits):
         "dummy2.weight_zero_point": torch.tensor(15, dtype=torch.int8),
     }

+    # pack-compressor only needs the number of bits from the quant-args to decompress
+    # all other information is extracted from the compressed data directly
     names_to_scheme = {
        "dummy": QuantizationArgs(num_bits=num_bits),
        "dummy2": QuantizationArgs(num_bits=num_bits),
     }
-    quant_config = get_dummy_quant_config(num_bits)
+    quant_config = get_dummy_quant_config(num_bits, symmetric=False)

     compressor = PackedQuantizationCompressor(config=quant_config)
     quantized_modules_to_args = {
         "dummy": quant_config.config_groups["group_1"].weights,
         "dummy2": quant_config.config_groups["group_1"].weights,
     }
+
     compressed_state_dict = compressor.compress(
         dense_state_dict, names_to_scheme=quantized_modules_to_args
     )
     save_file(compressed_state_dict, tmp_path / "model.safetensors")
+
     reconstructed_dense_gen = compressor.decompress(
         tmp_path, names_to_scheme=names_to_scheme
     )
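
Switching the reload test to symmetric=False matters because an asymmetric range produces a nonzero zero-point; if compression silently dropped it, decompression could not reproduce the fake-quantized weights. A plain-torch illustration of why the zero-point is needed for the round trip (numbers are arbitrary and not taken from the test):

import torch

w = torch.tensor([-0.4, 0.0, 0.6, 1.1])            # values spanning an asymmetric range
qmin, qmax = 0, 15                                  # unsigned 4-bit range
scale = (w.max() - w.min()) / (qmax - qmin)
zero_point = qmin - torch.round(w.min() / scale)    # nonzero (4 here) for asymmetric data
q = torch.clamp(torch.round(w / scale) + zero_point, qmin, qmax)
w_hat = (q - zero_point) * scale                    # zero-point is required to recover w
print(q)      # quantized integer levels
print(w_hat)  # matches the original values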
