Add Weight48IntegerQuantizer
helloyongyang committed Jan 8, 2025
1 parent b4e04eb commit 5a4324f
Showing 1 changed file with 171 additions and 1 deletion.
172 changes: 171 additions & 1 deletion llmc/compression/quantization/quant.py
@@ -1,6 +1,5 @@
import torch
from loguru import logger
-from torch import nn


class BaseQuantizer(object):
@@ -1206,3 +1205,174 @@ def __repr__(self):
f'granularity={self.granularity},'
f'kwargs={self.kwargs}, qmin={self.qmin}, qmax={self.qmax})'
)


class Weight48IntegerQuantizer(BaseQuantizer):
# flake8: noqa
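    '''Two-level 4/8-bit integer weight quantizer (bit == 48).

    Weights are first fake-quantized to 8-bit integers; the resulting 8-bit
    values are then quantized again to 4-bit. `bit4` and `bit8` are dicts
    holding the per-level settings: symmetry, granularity, an optional
    explicit `int_range`, and (for the 4-bit level) optional bit widths for
    its scales and zero-points.
    '''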
def __init__(self, bit, bit4, bit8, **kwargs):
super().__init__(bit, None, None, **kwargs)
self.quant_type = 'int-quant-w48'
assert self.bit == 48, 'Only support 48-bit quantization'
self.bit_settings = {}
self.bit_settings[4] = bit4
self.bit_settings[8] = bit8
for bit in [4, 8]:
if 'int_range' in self.bit_settings[bit]:
self.bit_settings[bit]['qmin'] = self.bit_settings[bit]['int_range'][0]
self.bit_settings[bit]['qmax'] = self.bit_settings[bit]['int_range'][1]
else:
if self.bit_settings[bit]['symmetric']:
self.bit_settings[bit]['qmin'] = -(2 ** (bit - 1))
self.bit_settings[bit]['qmax'] = 2 ** (bit - 1) - 1
else:
self.bit_settings[bit]['qmin'] = 0
self.bit_settings[bit]['qmax'] = 2 ** bit - 1
self.bit_settings[bit]['qmin'] = torch.tensor(self.bit_settings[bit]['qmin'])
self.bit_settings[bit]['qmax'] = torch.tensor(self.bit_settings[bit]['qmax'])
if 'scales_bit' in self.bit_settings[bit]:
if self.bit_settings[bit]['scales_symmetric']:
self.bit_settings[bit]['scales_qmin'] = -(2 ** (self.bit_settings[bit]['scales_bit'] - 1))
self.bit_settings[bit]['scales_qmax'] = 2 ** (self.bit_settings[bit]['scales_bit'] - 1) - 1
else:
self.bit_settings[bit]['scales_qmin'] = 0
self.bit_settings[bit]['scales_qmax'] = 2 ** self.bit_settings[bit]['scales_bit'] - 1
else:
self.bit_settings[bit]['scales_qmin'] = -torch.inf
self.bit_settings[bit]['scales_qmax'] = torch.inf
            if 'zeros_bit' in self.bit_settings[bit]:
                if self.bit_settings[bit]['zeros_symmetric']:
                    self.bit_settings[bit]['zeros_qmin'] = -(2 ** (self.bit_settings[bit]['zeros_bit'] - 1))
                    self.bit_settings[bit]['zeros_qmax'] = 2 ** (self.bit_settings[bit]['zeros_bit'] - 1) - 1
                else:
                    self.bit_settings[bit]['zeros_qmin'] = 0
                    self.bit_settings[bit]['zeros_qmax'] = 2 ** self.bit_settings[bit]['zeros_bit'] - 1
else:
self.bit_settings[bit]['zeros_qmin'] = self.bit_settings[bit]['qmin']
self.bit_settings[bit]['zeros_qmax'] = self.bit_settings[bit]['qmax']

def reshape_tensor(self, tensor, bit):
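        # per_group granularity flattens the tensor into rows of group_size
        # elements so each group gets its own scale/zero-point; any other
        # granularity leaves the tensor unchanged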
granularity = self.bit_settings[bit].get('granularity')
if granularity == 'per_group':
group_size = self.bit_settings[bit].get('group_size')
if tensor.shape[-1] % group_size == 0:
t = tensor.reshape(-1, group_size)
else:
raise ValueError(
f'Dimension {tensor.shape[-1]} '
f'not divisible by group size {group_size}'
)
else:
t = tensor
return t

def get_qparams(self, tensor_range, device, bit):
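        # symmetric: zero-point fixed at 0 and the scale maps the largest
        # magnitude onto qmax; asymmetric: scale and zero-point map
        # [min_val, max_val] onto [qmin, qmax]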
min_val, max_val = tensor_range[0], tensor_range[1]
qmin = self.bit_settings[bit]['qmin'].to(device)
qmax = self.bit_settings[bit]['qmax'].to(device)
sym = self.bit_settings[bit]['symmetric']
if sym:
abs_max = torch.max(max_val.abs(), min_val.abs())
abs_max = abs_max.clamp(min=1e-5)
scales = abs_max / qmax
zeros = torch.tensor(0.0)
else:
scales = (max_val - min_val).clamp(min=1e-5) / (qmax - qmin)
zeros = (qmin - torch.round(min_val / scales))
scales = scales.clamp(self.bit_settings[bit]['scales_qmin'], self.bit_settings[bit]['scales_qmax'])
zeros = zeros.clamp(self.bit_settings[bit]['zeros_qmin'], self.bit_settings[bit]['zeros_qmax'])
return scales, zeros, qmax, qmin

def quant(self, tensor, scales, zeros, qmax, qmin):
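        # affine fake-quant: q = clamp(round(x / scale) + zero_point, qmin, qmax)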
tensor = torch.clamp(self.round_func(tensor / scales) + zeros, qmin, qmax)
return tensor

def dequant(self, tensor, scales, zeros):
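        # inverse affine map: x_hat = (q - zero_point) * scale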
tensor = (tensor - zeros) * scales
return tensor

def quant_dequant(self, tensor, scales, zeros, qmax, qmin):
tensor = self.quant(tensor, scales, zeros, qmax, qmin)
tensor = self.dequant(tensor, scales, zeros)
return tensor

def fake_quant_weight_dynamic(self, weight, args={}):
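        # Cascaded fake quantization: full-precision weight -> 8-bit integers
        # -> 4-bit integers, then dequantized back through both levels.
        # scales816/zeros816 map between the original and 8-bit values;
        # scales48/zeros48 map between the 8-bit and 4-bit values.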
# step 1: quantize to 8-bit
org_shape16 = weight.shape
org_dtype16 = weight.dtype
weight = self.reshape_tensor(weight, bit=8)
weight_range = self.get_tensor_range(weight)
scales816, zeros816, qmax816, qmin816 = self.get_qparams(weight_range, weight.device, bit=8)
weight = self.quant(weight, scales816, zeros816, qmax816, qmin816)

# step 2: quantize to 4-bit
org_shape8 = weight.shape
org_dtype8 = weight.dtype
weight = self.reshape_tensor(weight, bit=4)
weight_range = self.get_tensor_range(weight)
scales48, zeros48, qmax48, qmin48 = self.get_qparams(weight_range, weight.device, bit=4)
weight = self.quant(weight, scales48, zeros48, qmax48, qmin48)

# step 3: dequantize to 8-bit
weight = self.dequant(weight, scales48, zeros48)
weight = self.restore_tensor(weight, org_shape8).to(org_dtype8)

# step 4: dequantize to 16-bit
weight = self.dequant(weight, scales816, zeros816)
weight = self.restore_tensor(weight, org_shape16).to(org_dtype16)

return weight


if __name__ == '__main__':
def test_Weight48IntegerQuantizer():
torch.manual_seed(0)
torch.cuda.manual_seed(0)

weight = torch.randn(4096, 8192).cuda()
print(weight)

'''
weight:
bit: 48
bit4:
symmetric: False
granularity: per_group
group_size: 128
scales_bit: 8
scales_symmetric: True
zeros_bit: 8
zeros_symmetric: True
bit8:
symmetric: True
granularity: per_channel
int_range: [-120, 120]
'''
cfg = {
'bit': 48,
'bit4': {
'symmetric': False,
'granularity': 'per_group',
'group_size': 128,
'scales_bit': 8,
'scales_symmetric': True,
'zeros_bit': 8,
'zeros_symmetric': True
},
'bit8': {
'symmetric': True,
'granularity': 'per_channel',
'int_range': [-120, 120]
}
}

int_quant = Weight48IntegerQuantizer(**cfg)

int_weight = int_quant.fake_quant_weight_dynamic(weight)

print(int_weight)
from torch import nn
cosine_sim = nn.CosineSimilarity()
cos = cosine_sim(weight.float().view(1, -1), int_weight.float().view(1, -1))
print(cos)
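        # a similarity close to 1.0 means the two-level quantization
        # preserved the weights well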

test_Weight48IntegerQuantizer()
