From 5a4324f0846a803644f66a5c3375a11478bb4e47 Mon Sep 17 00:00:00 2001
From: helloyongyang
Date: Wed, 8 Jan 2025 15:22:58 +0800
Subject: [PATCH] Add Weight48IntegerQuantizer

---
 llmc/compression/quantization/quant.py | 172 ++++++++++++++++++++++++-
 1 file changed, 171 insertions(+), 1 deletion(-)

diff --git a/llmc/compression/quantization/quant.py b/llmc/compression/quantization/quant.py
index 472f1579..c0c8de1f 100644
--- a/llmc/compression/quantization/quant.py
+++ b/llmc/compression/quantization/quant.py
@@ -1,6 +1,5 @@
 import torch
 from loguru import logger
-from torch import nn
 
 
 class BaseQuantizer(object):
@@ -1206,3 +1205,174 @@ def __repr__(self):
             f'granularity={self.granularity},'
             f'kwargs={self.kwargs}, qmin={self.qmin}, qmax={self.qmax})'
         )
+
+
+class Weight48IntegerQuantizer(BaseQuantizer):
+    # flake8: noqa
+    def __init__(self, bit, bit4, bit8, **kwargs):
+        super().__init__(bit, None, None, **kwargs)
+        self.quant_type = 'int-quant-w48'
+        assert self.bit == 48, 'Only support 48-bit quantization'
+        self.bit_settings = {}
+        self.bit_settings[4] = bit4
+        self.bit_settings[8] = bit8
+        for bit in [4, 8]:
+            if 'int_range' in self.bit_settings[bit]:
+                self.bit_settings[bit]['qmin'] = self.bit_settings[bit]['int_range'][0]
+                self.bit_settings[bit]['qmax'] = self.bit_settings[bit]['int_range'][1]
+            else:
+                if self.bit_settings[bit]['symmetric']:
+                    self.bit_settings[bit]['qmin'] = -(2 ** (bit - 1))
+                    self.bit_settings[bit]['qmax'] = 2 ** (bit - 1) - 1
+                else:
+                    self.bit_settings[bit]['qmin'] = 0
+                    self.bit_settings[bit]['qmax'] = 2 ** bit - 1
+            self.bit_settings[bit]['qmin'] = torch.tensor(self.bit_settings[bit]['qmin'])
+            self.bit_settings[bit]['qmax'] = torch.tensor(self.bit_settings[bit]['qmax'])
+            if 'scales_bit' in self.bit_settings[bit]:
+                if self.bit_settings[bit]['scales_symmetric']:
+                    self.bit_settings[bit]['scales_qmin'] = -(2 ** (self.bit_settings[bit]['scales_bit'] - 1))
+                    self.bit_settings[bit]['scales_qmax'] = 2 ** (self.bit_settings[bit]['scales_bit'] - 1) - 1
+                else:
+                    self.bit_settings[bit]['scales_qmin'] = 0
+                    self.bit_settings[bit]['scales_qmax'] = 2 ** self.bit_settings[bit]['scales_bit'] - 1
+            else:
+                self.bit_settings[bit]['scales_qmin'] = -torch.inf
+                self.bit_settings[bit]['scales_qmax'] = torch.inf
+            if 'zeros_bit' in self.bit_settings[bit]:
+                if self.bit_settings[bit]['zeros_symmetric']:
+                    self.bit_settings[bit]['zeros_qmin'] = -(2 ** (self.bit_settings[bit]['zeros_bit'] - 1))
+                    self.bit_settings[bit]['zeros_qmax'] = 2 ** (self.bit_settings[bit]['zeros_bit'] - 1) - 1
+                else:
+                    self.bit_settings[bit]['zeros_qmin'] = 0
+                    self.bit_settings[bit]['zeros_qmax'] = 2 ** self.bit_settings[bit]['zeros_bit'] - 1
+            else:
+                self.bit_settings[bit]['zeros_qmin'] = self.bit_settings[bit]['qmin']
+                self.bit_settings[bit]['zeros_qmax'] = self.bit_settings[bit]['qmax']
+
+    def reshape_tensor(self, tensor, bit):
+        granularity = self.bit_settings[bit].get('granularity')
+        if granularity == 'per_group':
+            group_size = self.bit_settings[bit].get('group_size')
+            if tensor.shape[-1] % group_size == 0:
+                t = tensor.reshape(-1, group_size)
+            else:
+                raise ValueError(
+                    f'Dimension {tensor.shape[-1]} '
+                    f'not divisible by group size {group_size}'
+                )
+        else:
+            t = tensor
+        return t
+
+    def get_qparams(self, tensor_range, device, bit):
+        min_val, max_val = tensor_range[0], tensor_range[1]
+        qmin = self.bit_settings[bit]['qmin'].to(device)
+        qmax = self.bit_settings[bit]['qmax'].to(device)
+        sym = self.bit_settings[bit]['symmetric']
+        if sym:
+            abs_max = torch.max(max_val.abs(), min_val.abs())
+            abs_max = abs_max.clamp(min=1e-5)
+            scales = abs_max / qmax
+            zeros = torch.tensor(0.0)
+        else:
+            scales = (max_val - min_val).clamp(min=1e-5) / (qmax - qmin)
+            zeros = (qmin - torch.round(min_val / scales))
+        scales = scales.clamp(self.bit_settings[bit]['scales_qmin'], self.bit_settings[bit]['scales_qmax'])
+        zeros = zeros.clamp(self.bit_settings[bit]['zeros_qmin'], self.bit_settings[bit]['zeros_qmax'])
+        return scales, zeros, qmax, qmin
+
+    def quant(self, tensor, scales, zeros, qmax, qmin):
+        tensor = torch.clamp(self.round_func(tensor / scales) + zeros, qmin, qmax)
+        return tensor
+
+    def dequant(self, tensor, scales, zeros):
+        tensor = (tensor - zeros) * scales
+        return tensor
+
+    def quant_dequant(self, tensor, scales, zeros, qmax, qmin):
+        tensor = self.quant(tensor, scales, zeros, qmax, qmin)
+        tensor = self.dequant(tensor, scales, zeros)
+        return tensor
+
+    def fake_quant_weight_dynamic(self, weight, args={}):
+        # step 1: quantize to 8-bit
+        org_shape16 = weight.shape
+        org_dtype16 = weight.dtype
+        weight = self.reshape_tensor(weight, bit=8)
+        weight_range = self.get_tensor_range(weight)
+        scales816, zeros816, qmax816, qmin816 = self.get_qparams(weight_range, weight.device, bit=8)
+        weight = self.quant(weight, scales816, zeros816, qmax816, qmin816)
+
+        # step 2: quantize to 4-bit
+        org_shape8 = weight.shape
+        org_dtype8 = weight.dtype
+        weight = self.reshape_tensor(weight, bit=4)
+        weight_range = self.get_tensor_range(weight)
+        scales48, zeros48, qmax48, qmin48 = self.get_qparams(weight_range, weight.device, bit=4)
+        weight = self.quant(weight, scales48, zeros48, qmax48, qmin48)
+
+        # step 3: dequantize to 8-bit
+        weight = self.dequant(weight, scales48, zeros48)
+        weight = self.restore_tensor(weight, org_shape8).to(org_dtype8)
+
+        # step 4: dequantize to 16-bit
+        weight = self.dequant(weight, scales816, zeros816)
+        weight = self.restore_tensor(weight, org_shape16).to(org_dtype16)
+
+        return weight
+
+
+if __name__ == '__main__':
+    def test_Weight48IntegerQuantizer():
+        torch.manual_seed(0)
+        torch.cuda.manual_seed(0)
+
+        weight = torch.randn(4096, 8192).cuda()
+        print(weight)
+
+        '''
+        weight:
+            bit: 48
+            bit4:
+                symmetric: False
+                granularity: per_group
+                group_size: 128
+                scales_bit: 8
+                scales_symmetric: True
+                zeros_bit: 8
+                zeros_symmetric: True
+            bit8:
+                symmetric: True
+                granularity: per_channel
+                int_range: [-120, 120]
+        '''
+        cfg = {
+            'bit': 48,
+            'bit4': {
+                'symmetric': False,
+                'granularity': 'per_group',
+                'group_size': 128,
+                'scales_bit': 8,
+                'scales_symmetric': True,
+                'zeros_bit': 8,
+                'zeros_symmetric': True
+            },
+            'bit8': {
+                'symmetric': True,
+                'granularity': 'per_channel',
+                'int_range': [-120, 120]
+            }
+        }
+
+        int_quant = Weight48IntegerQuantizer(**cfg)
+
+        int_weight = int_quant.fake_quant_weight_dynamic(weight)
+
+        print(int_weight)
+
+        from torch import nn
+        cosine_sim = nn.CosineSimilarity()
+        cos = cosine_sim(weight.float().view(1, -1), int_weight.float().view(1, -1))
+        print(cos)
+
+    test_Weight48IntegerQuantizer()
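
Note on the added fake_quant_weight_dynamic: it composes two fake-quantization levels, a symmetric per-channel 8-bit pass followed by an asymmetric per-group 4-bit pass over the resulting 8-bit integers, then dequantizes back through both levels. The snippet below is a minimal sketch of that composition, assuming the test config's int_range of [-120, 120] and a toy group size of 4 instead of 128; it uses plain torch calls rather than the BaseQuantizer helpers (get_tensor_range, restore_tensor, round_func), so it illustrates the idea but not the patch's exact numerics.

    import torch

    torch.manual_seed(0)
    w = torch.randn(4, 8)  # toy weight; a group size of 4 is assumed for the 4-bit level

    # level 1: symmetric per-channel 8-bit, restricted to [-120, 120] as in the test config
    qmax8 = 120
    s8 = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / qmax8
    w8 = torch.clamp(torch.round(w / s8), -qmax8, qmax8)

    # level 2: asymmetric per-group 4-bit applied to the 8-bit integer values
    g = w8.reshape(-1, 4)
    gmin = g.amin(dim=-1, keepdim=True)
    gmax = g.amax(dim=-1, keepdim=True)
    s4 = (gmax - gmin).clamp(min=1e-5) / 15   # qmax - qmin = 15 for unsigned 4-bit
    z4 = -torch.round(gmin / s4)              # zero point with qmin = 0
    w4 = torch.clamp(torch.round(g / s4) + z4, 0, 15)

    # dequantize through both levels: 4-bit -> 8-bit -> float
    w_hat = ((w4 - z4) * s4).reshape(w.shape) * s8
    print(torch.nn.functional.cosine_similarity(w.view(1, -1), w_hat.view(1, -1)))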