fix quarot bug
helloyongyang committed Aug 4, 2024
1 parent 37a9f37 commit 02b1983
Showing 4 changed files with 31 additions and 27 deletions.
1 change: 1 addition & 0 deletions llmc/compression/blockwise_optimization.py
@@ -11,6 +11,7 @@ def __init__(self, model, quant_config, input, config):
         self.quant_config = quant_config
         self.sparsity_config = quant_config
         self.input = input
+        self.data_free = False if self.input else True
         self.config = config
         self.block_idx = None
         self.num_blocks = len(self.blocks)
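A note on the new flag: "False if self.input else True" is equivalent to "not self.input", so a missing or empty calibration input puts the optimizer into data-free mode. A minimal sketch of the intended gating (simplified class and method bodies; names other than data_free and block_transform are illustrative, not the actual llmc implementation):

    # Minimal sketch, not the actual llmc class: only the data_free
    # control flow mirrors the change in this commit.
    class BlockwiseOptSketch:
        def __init__(self, input=None):
            self.input = input
            self.data_free = not self.input

        def block_opt(self, block):
            if self.data_free:
                # No calibration data: transform weights directly.
                self.block_transform(block)
            else:
                # Calibration path: hook inputs, forward the block,
                # then transform using the captured features.
                ...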
53 changes: 28 additions & 25 deletions llmc/compression/quantization/base_blockwise_quantization.py
@@ -246,36 +246,39 @@ def block_opt(self, block):
         handles = []
         self.block_init(block)

-        for name in named_linears:
-            handles.append(
-                named_linears[name].register_forward_hook(
-                    functools.partial(
-                        self.cache_input_hook, name=name, feat_dict=input_feat
-                    )
-                )
-            )
+        if not self.data_free:
+            for name in named_linears:
+                handles.append(
+                    named_linears[name].register_forward_hook(
+                        functools.partial(
+                            self.cache_input_hook, name=name, feat_dict=input_feat
+                        )
+                    )
+                )

-        if self.quant_out:
-            self.block_forward(block)
-        else:
-            self.input['data'] = self.block_forward(block)
+            if self.quant_out:
+                self.block_forward(block)
+            else:
+                self.input['data'] = self.block_forward(block)

-        for h in handles:
-            h.remove()
-        torch.cuda.empty_cache()
-
-        self.block_transform(block, input_feat, self.input['kwargs'])
+            for h in handles:
+                h.remove()
+            torch.cuda.empty_cache()
+            self.block_transform(block, input_feat, self.input['kwargs'])
+        else:
+            self.block_transform(block)

-        if self.quant_out:
-            self.model.replace_module_block(
-                FakeQuantLinear,
-                block,
-                self.block_idx,
-                self.get_replacement_params(
-                    mode='fake_quant', w_only=self.w_only, name=None
-                ),
-            )
-            self.input['data'] = self.block_forward(block)
+        if not self.data_free:
+            if self.quant_out:
+                self.model.replace_module_block(
+                    FakeQuantLinear,
+                    block,
+                    self.block_idx,
+                    self.get_replacement_params(
+                        mode='fake_quant', w_only=self.w_only, name=None
+                    ),
+                )
+                self.input['data'] = self.block_forward(block)

         block = block.cpu()
         del input_feat
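For context on the calibration branch above: cache_input_hook is registered as a PyTorch forward hook, with functools.partial binding the name and feat_dict arguments. A hedged sketch of what such a hook typically looks like (the exact llmc signature may differ):

    def cache_input_hook(module, inputs, output, name, feat_dict):
        # Forward hooks receive (module, inputs, output); the bound
        # name/feat_dict let each layer cache its own input tensors.
        feat_dict.setdefault(name, []).append(inputs[0].detach().cpu())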
2 changes: 1 addition & 1 deletion llmc/compression/quantization/quarot.py
@@ -60,7 +60,7 @@ def get_orthogonal_matrix(self):
         else:
             raise ValueError(f'Unsupported mode {self.mode}')

-    def block_transform(self, block, input_feat, block_kwargs):
+    def block_transform(self, block):
         logger.info(f'Start transform the {self.block_idx+1}-th block')

         if self.online_rotate:
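This signature change is what lets the new data-free branch in block_opt call block_transform(block) directly: QuaRot's rotations act on weights alone, so the cached input_feat and block kwargs are no longer needed. A hypothetical sketch of a weight-only transform consistent with the new call site (self.Q and the module walk are illustrative, not the actual QuaRot code):

    import torch

    class QuaRotSketch:
        def block_transform(self, block):
            # Data-free: multiply each linear weight by an assumed
            # orthogonal rotation self.Q; no activations required.
            for module in block.modules():
                if isinstance(module, torch.nn.Linear):
                    module.weight.data = module.weight.data @ self.Q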
2 changes: 1 addition & 1 deletion scripts/run_quarot_llama.sh
@@ -9,7 +9,7 @@ export PYTHONPATH=$llmc:$PYTHONPATH
 task_name=llm_quant_exp

 nohup \
-python -m llmc --config ../configs/quantization/QuaRot/quarot_w4a4.yml\
+python -m llmc --config ../configs/quantization/QuaRot/quarot_w4a4.yml \
 > ${task_name}.log 2>&1 &

 echo $! > ${task_name}.pid
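The script fix is a single character: a space before the line-continuation backslash, so the backslash no longer abuts the config path. The command that ultimately runs is unchanged.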
