From 8bc7f28fc2be7c90e0668a13fae1461fedb6a3b7 Mon Sep 17 00:00:00 2001 From: gushiqiao Date: Mon, 13 Jan 2025 15:28:26 +0800 Subject: [PATCH 1/3] Support ShadowKV and fix bugs --- config.yml | 46 ++ .../backend/sglang/fp8/awq_fp8.yml | 3 +- .../backend/sglang/fp8/awq_fp8_static.yml | 3 +- .../backend/sglang/fp8/gptq_fp8.yml | 3 +- .../backend/sglang/fp8/rtn_fp8.yml | 3 +- .../backend/sglang/fp8/smoothquant_fp8.yml | 3 +- .../quantization/backend/vllm/fp8/awq_fp8.yml | 3 +- .../backend/vllm/fp8/awq_fp8_static.yml | 3 +- .../backend/vllm/fp8/gptq_fp8.yml | 3 +- .../quantization/backend/vllm/fp8/rtn_fp8.yml | 3 +- .../backend/vllm/fp8/smoothquant_fp8.yml | 3 +- .../methods/FP_Quant/awq_we2m1a16_g128.yml | 2 +- .../methods/FP_Quant/gptq_we2m1a16_g128.yml | 2 +- .../methods/FP_Quant/rtn_we2m1a16_g128.yml | 2 +- .../methods/FP_Quant/rtn_we2m1ae2m1.yml | 3 +- .../methods/FP_Quant/rtn_we4m3ae4m3.yml | 3 +- .../methods/FP_Quant/rtn_we5m2ae5m2.yml | 3 +- .../methods/KVQuant/rtn_w_a_kivi_quant_kv.yml | 4 +- .../KVQuant/rtn_w_a_naive_quant_kv.yml | 4 +- .../methods/KVQuant/rtn_w_a_sink_quant_kv.yml | 38 - .../methods/RTN/rtn_w_a_pertensor_static.yml | 1 + .../methods/RTN/rtn_w_a_wint4aint8.yml | 41 ++ .../methods/Kvsparse/shadowkv.yml | 22 + .../methods/Kvsparse/sinkkv.yml | 25 + .../methods/Magnitude/magnitude.yml | 3 +- .../methods/ShortGPT/shortgpt.yml | 3 +- .../sparsification/methods/Wanda/wanda.yml | 3 +- llmc/__main__.py | 24 +- llmc/compression/blockwise_optimization.py | 64 +- llmc/compression/quantization/__init__.py | 2 +- llmc/compression/quantization/attn_utils.py | 402 +++++++++++ .../base_blockwise_quantization.py | 51 +- llmc/compression/quantization/kvquant.py | 165 ----- llmc/compression/quantization/module_utils.py | 397 ----------- llmc/compression/quantization/quant.py | 174 ++++- llmc/compression/sparsification/__init__.py | 3 +- llmc/compression/sparsification/attn_utils.py | 144 ++++ .../base_blockwise_sparsification.py | 130 +++- llmc/compression/sparsification/dense.py | 16 + llmc/compression/sparsification/kvsparse.py | 653 ++++++++++++++++++ llmc/compression/sparsification/shortgpt.py | 13 +- llmc/compression/sparsification/sparse.py | 9 - llmc/eval/eval_ppl.py | 1 - llmc/eval/utils.py | 2 +- 44 files changed, 1746 insertions(+), 739 deletions(-) create mode 100644 config.yml delete mode 100644 configs/quantization/methods/KVQuant/rtn_w_a_sink_quant_kv.yml create mode 100644 configs/quantization/methods/RTN/rtn_w_a_wint4aint8.yml create mode 100644 configs/sparsification/methods/Kvsparse/shadowkv.yml create mode 100644 configs/sparsification/methods/Kvsparse/sinkkv.yml create mode 100644 llmc/compression/quantization/attn_utils.py create mode 100644 llmc/compression/sparsification/attn_utils.py create mode 100644 llmc/compression/sparsification/dense.py create mode 100644 llmc/compression/sparsification/kvsparse.py delete mode 100644 llmc/compression/sparsification/sparse.py diff --git a/config.yml b/config.yml new file mode 100644 index 00000000..f62fb265 --- /dev/null +++ b/config.yml @@ -0,0 +1,46 @@ +base: + seed: &seed 42 +model: + type: Qwen2 + path: /home/gushiqiao/nvme/gushiqiao/bussinesss/code_72b/SenseChat-Code-Tmp + tokenizer_mode: fast + torch_dtype: auto +calib: + name: pileval + download: False + path: /home/gushiqiao/nvme/gushiqiao/llm_datasets/calib/pileval + n_samples: 256 + bs: -1 + seq_len: 512 + preproc: txt_general_preproc + seed: *seed +# eval: +# - eval_pos: [ fake_quant] +# name: wikitext2 +# download: False +# path: 
/home/gushiqiao/nvme/gushiqiao/llm_datasets/eval/wikitext2 +# seq_len: 2048 +# # For 7B / 13B model eval, bs can be set to "1", and inference_per_block can be set to "False". +# # For 70B model eval, bs can be set to "20", and inference_per_block can be set to "True". +# bs: 10 +# inference_per_block: True +quant: + method: Awq + weight: + bit: 8 + symmetric: True + granularity: per_channel + group_size: -1 + act: + bit: 8 + symmetric: True + granularity: per_token + special: + trans: True + trans_version: v2 + weight_clip: False + awq_bs: 128 + quant_out: True +save: + save_trans: True + save_path: ./awq_test_new_pileval_down_ov/ diff --git a/configs/quantization/backend/sglang/fp8/awq_fp8.yml b/configs/quantization/backend/sglang/fp8/awq_fp8.yml index 8cdc4a63..b2f5396b 100644 --- a/configs/quantization/backend/sglang/fp8/awq_fp8.yml +++ b/configs/quantization/backend/sglang/fp8/awq_fp8.yml @@ -26,14 +26,15 @@ eval: inference_per_block: False quant: method: Awq - quant_type: float-quant weight: + quant_type: float-quant # Support ["e4m3", "e5m2"] bit: e4m3 symmetric: True granularity: per_channel use_qtorch: True act: + quant_type: float-quant # Support ["e4m3", "e5m2"] bit: e4m3 symmetric: True diff --git a/configs/quantization/backend/sglang/fp8/awq_fp8_static.yml b/configs/quantization/backend/sglang/fp8/awq_fp8_static.yml index 980c1940..6c86e55a 100644 --- a/configs/quantization/backend/sglang/fp8/awq_fp8_static.yml +++ b/configs/quantization/backend/sglang/fp8/awq_fp8_static.yml @@ -26,14 +26,15 @@ eval: inference_per_block: False quant: method: Awq - quant_type: float-quant weight: + quant_type: float-quant # Support ["e4m3", "e5m2"] bit: e4m3 symmetric: True granularity: per_tensor use_qtorch: True act: + quant_type: float-quant # Support ["e4m3", "e5m2"] bit: e4m3 symmetric: True diff --git a/configs/quantization/backend/sglang/fp8/gptq_fp8.yml b/configs/quantization/backend/sglang/fp8/gptq_fp8.yml index 0b592396..85b4bde2 100644 --- a/configs/quantization/backend/sglang/fp8/gptq_fp8.yml +++ b/configs/quantization/backend/sglang/fp8/gptq_fp8.yml @@ -26,14 +26,15 @@ eval: inference_per_block: False quant: method: GPTQ - quant_type: float-quant weight: + quant_type: float-quant # Support ["e4m3", "e5m2"] bit: e4m3 symmetric: True granularity: per_channel use_qtorch: True act: + quant_type: float-quant # Support ["e4m3", "e5m2"] bit: e4m3 symmetric: True diff --git a/configs/quantization/backend/sglang/fp8/rtn_fp8.yml b/configs/quantization/backend/sglang/fp8/rtn_fp8.yml index 2d34b706..849973fe 100644 --- a/configs/quantization/backend/sglang/fp8/rtn_fp8.yml +++ b/configs/quantization/backend/sglang/fp8/rtn_fp8.yml @@ -17,13 +17,14 @@ eval: inference_per_block: False quant: method: RTN - quant_type: float-quant weight: + quant_type: float-quant bit: e4m3 symmetric: True granularity: per_channel use_qtorch: True act: + quant_type: float-quant bit: e4m3 symmetric: True granularity: per_token diff --git a/configs/quantization/backend/sglang/fp8/smoothquant_fp8.yml b/configs/quantization/backend/sglang/fp8/smoothquant_fp8.yml index 85b823ad..e0caa7ae 100644 --- a/configs/quantization/backend/sglang/fp8/smoothquant_fp8.yml +++ b/configs/quantization/backend/sglang/fp8/smoothquant_fp8.yml @@ -22,14 +22,15 @@ eval: seq_len: 2048 quant: method: SmoothQuant - quant_type: float-quant weight: + quant_type: float-quant # Support ["e4m3", "e5m2"] bit: e4m3 symmetric: True granularity: per_channel use_qtorch: True act: + quant_type: float-quant # Support ["e4m3", "e5m2"] bit: e4m3 symmetric: 
True diff --git a/configs/quantization/backend/vllm/fp8/awq_fp8.yml b/configs/quantization/backend/vllm/fp8/awq_fp8.yml index 805c7f45..3a282259 100644 --- a/configs/quantization/backend/vllm/fp8/awq_fp8.yml +++ b/configs/quantization/backend/vllm/fp8/awq_fp8.yml @@ -26,14 +26,15 @@ eval: inference_per_block: False quant: method: Awq - quant_type: float-quant weight: + quant_type: float-quant # Support ["e4m3", "e5m2"] bit: e4m3 symmetric: True granularity: per_channel use_qtorch: True act: + quant_type: float-quant # Support ["e4m3", "e5m2"] bit: e4m3 symmetric: True diff --git a/configs/quantization/backend/vllm/fp8/awq_fp8_static.yml b/configs/quantization/backend/vllm/fp8/awq_fp8_static.yml index df0cb334..1f09e77c 100644 --- a/configs/quantization/backend/vllm/fp8/awq_fp8_static.yml +++ b/configs/quantization/backend/vllm/fp8/awq_fp8_static.yml @@ -26,8 +26,8 @@ eval: inference_per_block: False quant: method: Awq - quant_type: float-quant weight: + quant_type: float-quant # Support ["e4m3", "e5m2"] bit: e4m3 symmetric: True @@ -35,6 +35,7 @@ quant: use_qtorch: True act: # Support ["e4m3", "e5m2"] + quant_type: float-quant bit: e4m3 symmetric: True granularity: per_tensor diff --git a/configs/quantization/backend/vllm/fp8/gptq_fp8.yml b/configs/quantization/backend/vllm/fp8/gptq_fp8.yml index 163f65c4..ec5db89f 100644 --- a/configs/quantization/backend/vllm/fp8/gptq_fp8.yml +++ b/configs/quantization/backend/vllm/fp8/gptq_fp8.yml @@ -26,14 +26,15 @@ eval: inference_per_block: False quant: method: GPTQ - quant_type: float_quant weight: + quant_type: float_quant # Support ["e4m3", "e5m2"] bit: e4m3 symmetric: True granularity: per_channel use_qtorch: True act: + quant_type: float_quant # Support ["e4m3", "e5m2"] bit: e4m3 symmetric: True diff --git a/configs/quantization/backend/vllm/fp8/rtn_fp8.yml b/configs/quantization/backend/vllm/fp8/rtn_fp8.yml index d982f5d6..f06f492e 100644 --- a/configs/quantization/backend/vllm/fp8/rtn_fp8.yml +++ b/configs/quantization/backend/vllm/fp8/rtn_fp8.yml @@ -17,13 +17,14 @@ eval: inference_per_block: False quant: method: RTN - quant_type: float-quant weight: + quant_type: float-quant bit: e4m3 symmetric: True granularity: per_channel use_qtorch: True act: + quant_type: float-quant bit: e4m3 symmetric: True granularity: per_token diff --git a/configs/quantization/backend/vllm/fp8/smoothquant_fp8.yml b/configs/quantization/backend/vllm/fp8/smoothquant_fp8.yml index 1c41dc1c..1c97ce11 100644 --- a/configs/quantization/backend/vllm/fp8/smoothquant_fp8.yml +++ b/configs/quantization/backend/vllm/fp8/smoothquant_fp8.yml @@ -22,14 +22,15 @@ eval: seq_len: 2048 quant: method: SmoothQuant - quant_type: float-quant weight: + quant_type: float-quant # Support ["e4m3", "e5m2"] bit: e4m3 symmetric: True granularity: per_channel use_qtorch: True act: + quant_type: float-quant # Support ["e4m3", "e5m2"] bit: e4m3 symmetric: True diff --git a/configs/quantization/methods/FP_Quant/awq_we2m1a16_g128.yml b/configs/quantization/methods/FP_Quant/awq_we2m1a16_g128.yml index 76d203dd..5cce5f08 100644 --- a/configs/quantization/methods/FP_Quant/awq_we2m1a16_g128.yml +++ b/configs/quantization/methods/FP_Quant/awq_we2m1a16_g128.yml @@ -25,8 +25,8 @@ eval: inference_per_block: False quant: method: Awq - quant_type: float-quant weight: + quant_type: float-quant bit: e2m1 symmetric: False granularity: per_group diff --git a/configs/quantization/methods/FP_Quant/gptq_we2m1a16_g128.yml b/configs/quantization/methods/FP_Quant/gptq_we2m1a16_g128.yml index f18de836..d87b1de5 100644 
--- a/configs/quantization/methods/FP_Quant/gptq_we2m1a16_g128.yml +++ b/configs/quantization/methods/FP_Quant/gptq_we2m1a16_g128.yml @@ -26,8 +26,8 @@ eval: inference_per_block: False quant: method: GPTQ - quant_type: float-quant weight: + quant_type: float-quant bit: e2m1 symmetric: True granularity: per_group diff --git a/configs/quantization/methods/FP_Quant/rtn_we2m1a16_g128.yml b/configs/quantization/methods/FP_Quant/rtn_we2m1a16_g128.yml index 5e2cc61e..c55be361 100644 --- a/configs/quantization/methods/FP_Quant/rtn_we2m1a16_g128.yml +++ b/configs/quantization/methods/FP_Quant/rtn_we2m1a16_g128.yml @@ -16,8 +16,8 @@ eval: inference_per_block: False quant: method: RTN - quant_type: float-quant weight: + quant_type: float-quant bit: e2m1 symmetric: True granularity: per_group diff --git a/configs/quantization/methods/FP_Quant/rtn_we2m1ae2m1.yml b/configs/quantization/methods/FP_Quant/rtn_we2m1ae2m1.yml index 53169b36..b5ea6fa2 100644 --- a/configs/quantization/methods/FP_Quant/rtn_we2m1ae2m1.yml +++ b/configs/quantization/methods/FP_Quant/rtn_we2m1ae2m1.yml @@ -16,12 +16,13 @@ eval: inference_per_block: False quant: method: RTN - quant_type: float-quant weight: + quant_type: float-quant bit: e2m1 symmetric: True granularity: per_channel act: + quant_type: float-quant bit: e2m1 symmetric: True granularity: per_token diff --git a/configs/quantization/methods/FP_Quant/rtn_we4m3ae4m3.yml b/configs/quantization/methods/FP_Quant/rtn_we4m3ae4m3.yml index c203eb19..1604eaf2 100644 --- a/configs/quantization/methods/FP_Quant/rtn_we4m3ae4m3.yml +++ b/configs/quantization/methods/FP_Quant/rtn_we4m3ae4m3.yml @@ -16,12 +16,13 @@ eval: inference_per_block: False quant: method: RTN - quant_type: float-quant weight: + quant_type: float-quant bit: e4m3 symmetric: True granularity: per_channel act: + quant_type: float-quant bit: e4m3 symmetric: True granularity: per_token diff --git a/configs/quantization/methods/FP_Quant/rtn_we5m2ae5m2.yml b/configs/quantization/methods/FP_Quant/rtn_we5m2ae5m2.yml index d90675cb..ed3ea2f4 100644 --- a/configs/quantization/methods/FP_Quant/rtn_we5m2ae5m2.yml +++ b/configs/quantization/methods/FP_Quant/rtn_we5m2ae5m2.yml @@ -16,12 +16,13 @@ eval: inference_per_block: False quant: method: RTN - quant_type: float-quant weight: + quant_type: float-quant bit: e5m2 symmetric: True granularity: per_channel act: + quant_type: float-quant bit: e5m2 symmetric: True granularity: per_token diff --git a/configs/quantization/methods/KVQuant/rtn_w_a_kivi_quant_kv.yml b/configs/quantization/methods/KVQuant/rtn_w_a_kivi_quant_kv.yml index 0a745dbc..9d454387 100644 --- a/configs/quantization/methods/KVQuant/rtn_w_a_kivi_quant_kv.yml +++ b/configs/quantization/methods/KVQuant/rtn_w_a_kivi_quant_kv.yml @@ -5,14 +5,14 @@ model: path: model path torch_dtype: auto eval: - eval_pos: [transformed, fake_quant, fake_quant_wo_kv] #long_ppl eval not support pretrain eval pos + eval_pos: [transformed, fake_quant, fake_quant_wo_kv] #decode_ppl eval not support pretrain eval pos name: wikitext2 type: decode_ppl download: False path: eval_data_path bs: 1 inference_per_block: False - num_samples: 10 + num_samples: 50 # num_eval_tokens: 3 quant: method: RTN diff --git a/configs/quantization/methods/KVQuant/rtn_w_a_naive_quant_kv.yml b/configs/quantization/methods/KVQuant/rtn_w_a_naive_quant_kv.yml index ade36ebc..4a9452e1 100644 --- a/configs/quantization/methods/KVQuant/rtn_w_a_naive_quant_kv.yml +++ b/configs/quantization/methods/KVQuant/rtn_w_a_naive_quant_kv.yml @@ -5,14 +5,14 @@ model: path: model 
path torch_dtype: auto eval: - eval_pos: [transformed, fake_quant, fake_quant_wo_kv] #long_ppl eval not support pretrain eval pos + eval_pos: [transformed, fake_quant, fake_quant_wo_kv] #decode_ppl eval not support pretrain eval pos name: wikitext2 type: decode_ppl download: False path: eval_data_path bs: 1 inference_per_block: False - num_samples: 10 + num_samples: 50 # num_eval_tokens: 3 quant: method: RTN diff --git a/configs/quantization/methods/KVQuant/rtn_w_a_sink_quant_kv.yml b/configs/quantization/methods/KVQuant/rtn_w_a_sink_quant_kv.yml deleted file mode 100644 index 040f260c..00000000 --- a/configs/quantization/methods/KVQuant/rtn_w_a_sink_quant_kv.yml +++ /dev/null @@ -1,38 +0,0 @@ -base: - seed: &seed 42 -model: - type: model_type - path: model path - torch_dtype: auto -eval: - eval_pos: [transformed, fake_quant, fake_quant_wo_kv] #long_ppl eval not support pretrain eval pos - name: wikitext2 - type: decode_ppl - download: False - path: eval_data_path - bs: 1 - inference_per_block: False - num_samples: 10 - # num_eval_tokens: 3 -quant: - method: RTN - weight: - bit: 8 - symmetric: True - granularity: per_channel - group_size: -1 - act: - bit: 8 - symmetric: True - granularity: per_token - kvcache: - method: Sink - bit: 4 - symmetric: True - granularity: per_token - special: - window_length: 512 - num_sink_tokens: 4 -save: - save_fake: False - save_path: /path/to/save/ diff --git a/configs/quantization/methods/RTN/rtn_w_a_pertensor_static.yml b/configs/quantization/methods/RTN/rtn_w_a_pertensor_static.yml index 46811d1d..64b2eab7 100644 --- a/configs/quantization/methods/RTN/rtn_w_a_pertensor_static.yml +++ b/configs/quantization/methods/RTN/rtn_w_a_pertensor_static.yml @@ -35,6 +35,7 @@ quant: symmetric: True granularity: per_tensor static: True + calib_algo: static_hist save: save_fake: False save_path: /path/to/save/ diff --git a/configs/quantization/methods/RTN/rtn_w_a_wint4aint8.yml b/configs/quantization/methods/RTN/rtn_w_a_wint4aint8.yml new file mode 100644 index 00000000..0e105a70 --- /dev/null +++ b/configs/quantization/methods/RTN/rtn_w_a_wint4aint8.yml @@ -0,0 +1,41 @@ +base: + seed: &seed 42 +model: + type: Llama + path: /mnt/nvme1/yongyang/models/llama2-7b + torch_dtype: auto +eval: + eval_pos: [pretrain, fake_quant] + name: wikitext2 + download: False + path: /mnt/nvme0/yongyang/llm_datasets/llmc/eval/wikitext2 + seq_len: 2048 + # For 7B / 13B model eval, bs can be set to "1", and inference_per_block can be set to "False". + # For 70B model eval, bs can be set to "20", and inference_per_block can be set to "True". 
+ bs: 1 + inference_per_block: False +quant: + method: RTN + weight: + quant_type: int-quant + bit: 48 + bit4: + symmetric: False + granularity: per_group + group_size: 128 + scales_bit: 8 + scales_symmetric: True + zeros_bit: 8 + zeros_symmetric: True + bit8: + symmetric: True + granularity: per_channel + int_range: [-120, 120] + act: + quant_type: int-quant + bit: 8 + symmetric: True + granularity: per_token +save: + save_fake: False + save_path: /path/to/save/ diff --git a/configs/sparsification/methods/Kvsparse/shadowkv.yml b/configs/sparsification/methods/Kvsparse/shadowkv.yml new file mode 100644 index 00000000..2a35ba49 --- /dev/null +++ b/configs/sparsification/methods/Kvsparse/shadowkv.yml @@ -0,0 +1,22 @@ +base: + seed: &seed 42 +model: + type: model_type + path: model path + torch_dtype: torch.bfloat16 +eval: + eval_pos: [transformed] + name: wikitext2 + download: False + path: eval_data_path + bs: 1 + seq_len: 2048 +sparse: + method: Dense + kvcache: + method: ShadowKV + replace_attn: True + sparsity_out: False +save: + save_trans: False + save_path: ./save diff --git a/configs/sparsification/methods/Kvsparse/sinkkv.yml b/configs/sparsification/methods/Kvsparse/sinkkv.yml new file mode 100644 index 00000000..c7800cc1 --- /dev/null +++ b/configs/sparsification/methods/Kvsparse/sinkkv.yml @@ -0,0 +1,25 @@ +base: + seed: &seed 42 +model: + type: model_type + path: model path + torch_dtype: torch.bfloat16 +eval: + eval_pos: [transformed] + name: wikitext2 + type: decode_ppl + download: False + path: eval_data_path + bs: 1 + inference_per_block: False + num_samples: 50 + # num_eval_tokens: 3 +sparse: + method: Dense + kvcache: + method: SinkKV + window_length: 256 + num_sink_tokens: 4 +save: + save_fake: False + save_path: /path/to/save/ diff --git a/configs/sparsification/methods/Magnitude/magnitude.yml b/configs/sparsification/methods/Magnitude/magnitude.yml index c0543e0a..98c23eb0 100644 --- a/configs/sparsification/methods/Magnitude/magnitude.yml +++ b/configs/sparsification/methods/Magnitude/magnitude.yml @@ -25,6 +25,5 @@ sparse: weight: sparsity: 0.5 save: - save_fp: False - save_lightllm: False + save_trans: False save_path: ./save diff --git a/configs/sparsification/methods/ShortGPT/shortgpt.yml b/configs/sparsification/methods/ShortGPT/shortgpt.yml index 1f1b78f3..05f285a9 100644 --- a/configs/sparsification/methods/ShortGPT/shortgpt.yml +++ b/configs/sparsification/methods/ShortGPT/shortgpt.yml @@ -24,7 +24,6 @@ sparse: weight: n_prune_layers: 9 save: - save_trans: True - save_fp: False + save_trans: False save_lightllm: False save_path: ./save diff --git a/configs/sparsification/methods/Wanda/wanda.yml b/configs/sparsification/methods/Wanda/wanda.yml index a2b39868..a1082bd3 100644 --- a/configs/sparsification/methods/Wanda/wanda.yml +++ b/configs/sparsification/methods/Wanda/wanda.yml @@ -26,6 +26,5 @@ sparse: sparsity: 0.5 sparsity_out: False save: - save_fp: False - save_lightllm: False + save_trans: False save_path: ./save diff --git a/llmc/__main__.py b/llmc/__main__.py index 765a4e35..01077eb6 100644 --- a/llmc/__main__.py +++ b/llmc/__main__.py @@ -45,13 +45,22 @@ def main(config): for modality in get_modality(config): model.set_modality(modality) if not config.get('calib', False): - blockwise_opt = ALGO_REGISTRY[config.quant.method]( - model, - quant_config=config.quant, - input=None, - padding_mask=None, - config=config, - ) + if not config.get('sparse', False): + blockwise_opt = ALGO_REGISTRY[config.quant.method]( + model, + config.quant, + None, + None, + 
config, + ) + else: + blockwise_opt = ALGO_REGISTRY[config.sparse.method]( + model, + config.sparse, + None, + None, + config, + ) blockwise_opt.run_block_loop() dist.barrier() else: @@ -98,6 +107,7 @@ def main(config): ) eval_model(model, blockwise_opt, eval_list, eval_pos='fake_quant') + eval_model(model, blockwise_opt, eval_list, eval_pos='fake_quant_wo_kv') if 'save' in config and config.save.get('save_fake', False): blockwise_opt.deploy('fake_quant') diff --git a/llmc/compression/blockwise_optimization.py b/llmc/compression/blockwise_optimization.py index 4bc0a7a1..72823d1b 100644 --- a/llmc/compression/blockwise_optimization.py +++ b/llmc/compression/blockwise_optimization.py @@ -6,11 +6,11 @@ class BlockwiseOpt(metaclass=ABCMeta): - def __init__(self, model, quant_config, input, padding_mask, config): + def __init__(self, model, compress_config, input, padding_mask, config): self.model = model self.blocks = model.get_blocks() - self.quant_config = quant_config - self.sparsity_config = quant_config + self.quant_config = compress_config + self.sparsity_config = compress_config self.input = input self.padding_mask = padding_mask self.data_free = False if self.input else True @@ -60,37 +60,41 @@ def cache_input_hook(self, m, x, y, name, feat_dict): else: feat_dict[name].append(tuple(inputs)) - def kv_cache_input_hook(self): + def kv_cache_input_hook(self, attn_layer): def hook_fn(module, args, kwargs): kvcache = getattr(module, 'kvcache') kwargs['past_key_value'] = kvcache - kwargs['use_cache'] = True - if kwargs['hidden_states'].shape[1] == 1: - if kwargs['position_ids'].shape[1] == 1: - # For eval decoding PPL (Perplexity), it will be removed in future versions. - past_seen_tokens = kvcache.get_seq_length() - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + kwargs['hidden_states'].shape[1], - device=kwargs['hidden_states'].device, + if self.config.eval.get('type', None) == 'decode_ppl': + # For eval decoding PPL (Perplexity). 
+ past_seen_tokens = kvcache.get_seq_length() + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + kwargs['hidden_states'].shape[1], + device=kwargs['hidden_states'].device, + ) + kwargs['cache_position'] = cache_position + position_ids = cache_position.unsqueeze(0) + kwargs['position_ids'] = position_ids + if 'position_embeddings' in kwargs: + kwargs['position_embeddings'] = self.model.rotary_emb( + kwargs['hidden_states'], position_ids ) - kwargs['cache_position'] = cache_position - position_ids = cache_position.unsqueeze(0) - kwargs['position_ids'] = position_ids - if 'position_embeddings' in kwargs: - kwargs['position_embeddings'] = self.model.rotary_emb( - kwargs['hidden_states'], position_ids - ) - else: - if self.config['model']['type'] in ['DeepseekV2']: - kwargs['position_ids'] = kwargs['position_ids'][:, -1].unsqueeze(1) - else: - kwargs['position_ids'] = \ - kwargs['position_ids'][:, -1].unsqueeze(0).unsqueeze(0) - if 'position_embeddings' in kwargs: - cos = kwargs['position_embeddings'][0][:, -1, :].unsqueeze(1) - sin = kwargs['position_embeddings'][1][:, -1, :].unsqueeze(1) - kwargs['position_embeddings'] = (cos, sin) + if kwargs['hidden_states'].shape[1] == 1: + from .sparsification.kvsparse import ShadowKVCache + if isinstance(kvcache, ShadowKVCache): + hidden_states = kwargs['hidden_states'][:, -1, :].unsqueeze(0) + kwargs['hidden_states'] = hidden_states + bsz, q_len, _ = hidden_states.size() + tmp_query_states = \ + attn_layer.q_proj(hidden_states).view(bsz, + q_len, + -1, + attn_layer.head_dim).transpose(1, 2) + retrieval_position_ids = \ + kvcache.get_retrieval_position_ids(layer_idx=attn_layer.layer_idx, + query_states=tmp_query_states) + kwargs['retrieval_position_ids'] = retrieval_position_ids + kwargs['cos_sin_cache'] = self.cos_sin_cache return args, kwargs diff --git a/llmc/compression/quantization/__init__.py b/llmc/compression/quantization/__init__.py index feefa41e..2c08343e 100644 --- a/llmc/compression/quantization/__init__.py +++ b/llmc/compression/quantization/__init__.py @@ -4,7 +4,7 @@ from .dgq import DGQ from .gptq import GPTQ from .hqq import HQQ -from .kvquant import NaiveQuantKVCache +from .kvquant import KiviQuantKVCache, NaiveQuantKVCache from .llmint8 import LlmInt8 from .module_utils import FakeQuantLinear from .ntweak import NormTweaking diff --git a/llmc/compression/quantization/attn_utils.py b/llmc/compression/quantization/attn_utils.py new file mode 100644 index 00000000..8eb4282e --- /dev/null +++ b/llmc/compression/quantization/attn_utils.py @@ -0,0 +1,402 @@ +import math + +import torch +import torch.nn as nn + + +class LlmcMatmul(nn.Module): + def __init__(self, a1_qdq=None, a2_qdq=None): + super().__init__() + self.a1_qdq = a1_qdq + self.a2_qdq = a2_qdq + self.calib = True + + def forward(self, x1, x2): + if self.a1_qdq is not None and not self.calib: + x1 = self.a1_qdq(x1, self) + if self.a2_qdq is not None and not self.calib: + x2 = self.a2_qdq(x2, self) + out = torch.matmul(x1, x2) + return out + + def __repr__(self): + return f'LlmcMatmul(calib={self.calib})' + + +class LlmcSoftmax(nn.Module): + def __init__(self, a_qdq=None): + super().__init__() + self.a_qdq = a_qdq + self.calib = True + + def forward(self, x, dim=-1, dtype=None): + if self.a_qdq is not None and not self.calib: + x = self.a_qdq(x, self) + out = nn.functional.softmax(x, dim=dim, dtype=dtype) + return out + + def __repr__(self): + return f'LlmcSoftmax(calib={self.calib})' + + +class LlmcViTSelfAttention(nn.Module): + def __init__( + self, + 
query, + key, + value, + num_attention_heads, + attention_head_size, + all_head_size, + dropout, + matmul_a1_qdq, + matmul_a2_qdq, + softmax_a_qdq, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_size = attention_head_size + self.all_head_size = all_head_size + self.query = query + self.key = key + self.value = value + + self.dropout = dropout + + self.matmul_1 = LlmcMatmul(matmul_a1_qdq, matmul_a2_qdq) + self.matmul_2 = LlmcMatmul(matmul_a1_qdq, matmul_a2_qdq) + self.softmax = LlmcSoftmax(softmax_a_qdq) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, head_mask=None, output_attentions=False): + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + attention_scores = self.matmul_1(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + attention_probs = self.softmax(attention_scores, dim=-1) + attention_probs = self.dropout(attention_probs) + + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = self.matmul_2(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + return outputs + + @classmethod + @torch.no_grad() + def new(cls, module, matmul_a1_qdq=None, matmul_a2_qdq=None, softmax_a_qdq=None): + query, key, value = module.query, module.key, module.value + num_attention_heads = module.num_attention_heads + attention_head_size = module.attention_head_size + all_head_size = module.all_head_size + dropout = module.dropout + new_module = cls( + query, + key, + value, + num_attention_heads, + attention_head_size, + all_head_size, + dropout, + matmul_a1_qdq, + matmul_a2_qdq, + softmax_a_qdq, + ) + return new_module + + def __repr__(self): + return ( + f'LlmcViTSelfAttention(\n' + f' (query): {self.query}\n' + f' (key): {self.key}\n' + f' (value): {self.value}\n' + f' (dropout): {self.dropout}\n' + f' (matmul_1): {self.matmul_1}\n' + f' (matmul_2): {self.matmul_2}\n' + f' (softmax): {self.softmax}\n' + f')' + ) + + +class LlmcDeepseekAttention(nn.Module): + def __init__( + self, + config, + layer_idx, + attention_dropout, + hidden_size, + num_heads, + max_position_embeddings, + rope_theta, + q_lora_rank, + qk_rope_head_dim, + kv_lora_rank, + v_head_dim, + qk_nope_head_dim, + q_head_dim, + is_causal, + q_proj, + q_a_proj, + q_a_layernorm, + q_b_proj, + kv_a_proj_with_mqa, + kv_a_layernorm, + kv_b_proj, + o_proj, + rotary_emb, + softmax_scale, + matmul_a1_qdq, + matmul_a2_qdq, + softmax_a_qdq, + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.attention_dropout = attention_dropout + self.hidden_size = hidden_size + self.num_heads = num_heads + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.kv_lora_rank = kv_lora_rank + self.v_head_dim = v_head_dim 
+ self.qk_nope_head_dim = qk_nope_head_dim + self.q_head_dim = q_head_dim + self.is_causal = is_causal + self.q_proj = q_proj + self.q_a_proj = q_a_proj + self.q_a_layernorm = q_a_layernorm + self.q_b_proj = q_b_proj + self.kv_a_proj_with_mqa = kv_a_proj_with_mqa + self.kv_a_layernorm = kv_a_layernorm + self.kv_b_proj = kv_b_proj + self.o_proj = o_proj + self.rotary_emb = rotary_emb + self.softmax_scale = softmax_scale + self.matmul_1 = LlmcMatmul(matmul_a1_qdq, matmul_a2_qdq) + self.matmul_2 = LlmcMatmul(matmul_a1_qdq, matmul_a2_qdq) + self.softmax = LlmcSoftmax(softmax_a_qdq) + + def _shape(self, tensor, seq_len, bsz): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim) + .transpose(1, 2) + .contiguous() + ) + + def rotate_half(self, x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids, unsqueeze_dim=1): + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (self.rotate_half(q) * sin) + k_embed = (k * cos) + (self.rotate_half(k) * sin) + return q_embed, k_embed + + @classmethod + @torch.no_grad() + def new(cls, module, matmul_a1_qdq=None, matmul_a2_qdq=None, softmax_a_qdq=None): + + config = module.config + layer_idx = module.layer_idx + + attention_dropout = module.config.attention_dropout + hidden_size = module.config.hidden_size + num_heads = module.config.num_attention_heads + + max_position_embeddings = module.config.max_position_embeddings + rope_theta = module.config.rope_theta + q_lora_rank = module.config.q_lora_rank + qk_rope_head_dim = module.config.qk_rope_head_dim + kv_lora_rank = module.config.kv_lora_rank + v_head_dim = module.config.v_head_dim + qk_nope_head_dim = module.config.qk_nope_head_dim + q_head_dim = module.q_head_dim + is_causal = module.is_causal + + if q_lora_rank is None: + q_proj = module.q_proj + q_a_proj = None + q_a_layernorm = None + q_b_proj = None + else: + q_proj = None + q_a_proj = module.q_a_proj + q_a_layernorm = module.q_a_layernorm + q_b_proj = module.q_b_proj + + kv_a_proj_with_mqa = module.kv_a_proj_with_mqa + kv_a_layernorm = module.kv_a_layernorm + kv_b_proj = module.kv_b_proj + + o_proj = module.o_proj + rotary_emb = module.rotary_emb + + softmax_scale = module.softmax_scale + + new_module = cls( + config=config, + layer_idx=layer_idx, + attention_dropout=attention_dropout, + hidden_size=hidden_size, + num_heads=num_heads, + max_position_embeddings=max_position_embeddings, + rope_theta=rope_theta, + q_lora_rank=q_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + kv_lora_rank=kv_lora_rank, + v_head_dim=v_head_dim, + qk_nope_head_dim=qk_nope_head_dim, + q_head_dim=q_head_dim, + is_causal=is_causal, + q_proj=q_proj, + q_a_proj=q_a_proj, + q_a_layernorm=q_a_layernorm, + q_b_proj=q_b_proj, + kv_a_proj_with_mqa=kv_a_proj_with_mqa, + kv_a_layernorm=kv_a_layernorm, + kv_b_proj=kv_b_proj, + o_proj=o_proj, + rotary_emb=rotary_emb, + softmax_scale=softmax_scale, + matmul_a1_qdq=matmul_a1_qdq, + matmul_a2_qdq=matmul_a2_qdq, + softmax_a_qdq=softmax_a_qdq, + ) + + return new_module + + def forward( + self, + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + 
**kwargs, + ): + bsz, q_len, _ = hidden_states.size() + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + q_pe, k_pe = self.apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim:] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim:] = k_pe + if past_key_value is not None: + cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attn_weights = ( + self.matmul_1(query_states, key_states.transpose(2, 3)) * self.softmax_scale + ) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)},' + f'but is {attn_weights.size()}' + ) + assert attention_mask is not None + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},' + f'but is {attention_mask.size()}' + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = self.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = self.matmul_2(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): + raise ValueError( + f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)},' + f' but is {attn_output.size()}' + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +_LLMC_ATTN_MAP_ = {'Vit': LlmcViTSelfAttention, 'DeepseekV2': LlmcDeepseekAttention} diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py index 5925c111..d4773030 100644 --- a/llmc/compression/quantization/base_blockwise_quantization.py +++ 
b/llmc/compression/quantization/base_blockwise_quantization.py @@ -11,19 +11,18 @@ import torch.nn as nn from loguru import logger -from llmc.utils import copy_files from llmc.utils.registry_factory import KV_REGISTRY from ..blockwise_optimization import BlockwiseOpt +from .attn_utils import _LLMC_ATTN_MAP_ from .auto_clip import AutoClipper from .hadamard_utils import apply_exact_had_to_linear, get_hadK -from .module_utils import (_LLMC_ATTN_MAP_, _LLMC_LINEAR_TYPES_, - _LLMC_LN_TYPES_, _REALQUANT_LINEAR_MAP_, - _TRANSFORMERS_LINEAR_TYPES_, +from .module_utils import (_LLMC_LINEAR_TYPES_, _LLMC_LN_TYPES_, + _REALQUANT_LINEAR_MAP_, _TRANSFORMERS_LINEAR_TYPES_, _TRANSFORMERS_LN_TYPES_, EffcientFakeQuantLinear, FakeQuantLinear, LlmcActFn, OriginFloatLinear, RotateLinear) -from .quant import FloatQuantizer, IntegerQuantizer +from .quant import FloatQuantizer, IntegerQuantizer, Weight48IntegerQuantizer from .utils import check_do_quant, check_w_only, get_aquantizer, get_wquantizer @@ -175,22 +174,32 @@ def set_quant_config(self): self.tp = self.quant_config.get('tp', 1) self.quant_config['weight']['tp'] = self.tp - # select quant module - self.quant_type = self.quant_config.get('quant_type', 'int-quant') - if self.quant_type == 'int-quant': - self.quant_module = IntegerQuantizer - elif self.quant_type == 'float-quant': - self.quant_module = FloatQuantizer - logger.info(f'The used Quant Module is {self.quant_module}') - - # set weight quant config - self.wquantizer = self.quant_module(**self.quant_config['weight']) + # select quantizer + # weight + quant_type = self.quant_config['weight'].get('quant_type', 'int-quant') + if quant_type == 'int-quant': + if self.quant_config['weight']['bit'] == 48: + self.weight_quant_module = Weight48IntegerQuantizer + else: + self.weight_quant_module = IntegerQuantizer + elif quant_type == 'float-quant': + self.weight_quant_module = FloatQuantizer + logger.info(f'The used Weight Quant Module is {self.weight_quant_module}') + self.wquantizer = self.weight_quant_module(**self.quant_config['weight']) - # set act quant config + # act if 'act' in self.quant_config: self.w_only = False + quant_type = self.quant_config['act'].get('quant_type', 'int-quant') + if quant_type == 'int-quant': + if self.quant_config['act']['bit'] == 48: + self.act_quant_module = Weight48IntegerQuantizer + else: + self.act_quant_module = IntegerQuantizer + elif quant_type == 'float-quant': + self.act_quant_module = FloatQuantizer self.quant_config['act']['tp'] = self.tp - self.aquantizer = self.quant_module(**self.quant_config['act']) + self.aquantizer = self.act_quant_module(**self.quant_config['act']) self.act_static = self.quant_config['act'].get('static', False) if self.act_static: assert ( @@ -230,7 +239,6 @@ def set_quant_config(self): if 'kvcache' in self.quant_config: self.quant_config['kvcache']['static'] = self.act_static kv_special_cfg = self.quant_config['kvcache'].get('special', {}) - logger.info(kv_special_cfg) act_static_cfg = {} if self.act_static: act_static_cfg.update(self.config.calib.n_sample) @@ -546,7 +554,7 @@ def register_kv_cache(self, block): attn_layer = attn_layers_dict[list(attn_layers_dict.keys())[0]] setattr(attn_layer, 'kvcache', self.kv_module) attn_layer.register_forward_pre_hook( - self.kv_cache_input_hook(), with_kwargs=True + self.kv_cache_input_hook(attn_layer), with_kwargs=True ) @torch.no_grad() @@ -900,10 +908,7 @@ def deploy(self, quant_format, keep_device=False): @torch.no_grad() def copy_tokenizer(self, path): - for substring in 
self.config.save.get( - 'tokenizer_file_substring', ['token', 'merges', 'vocab', 'preprocessor_config', 'chat_template'] # noqa - ): - copy_files(self.config.model.path, path, substring) + self.model.tokenizer.save_pretrained(path) logger.info('copy tokenizer done --') @torch.no_grad() diff --git a/llmc/compression/quantization/kvquant.py b/llmc/compression/quantization/kvquant.py index ed67a04c..32c2de5b 100644 --- a/llmc/compression/quantization/kvquant.py +++ b/llmc/compression/quantization/kvquant.py @@ -287,168 +287,3 @@ def update( torch.cat([self.value_cache[layer_idx], value_states], dim=-2) return keys_to_return, values_to_return - - -@KV_REGISTRY.register('Sink') -class SinkQuantKVCache(NaiveQuantKVCache): - def __init__( - self, - quant_type, - kvquant_cfg, - num_hidden_layers, - window_length, - num_sink_tokens, - num_samples=128, - bsz=1 - ): - super().__init__(quant_type, kvquant_cfg, num_hidden_layers, num_samples, bsz) - assert not self.static, 'Only support dynamic quantization for Sink' - self.window_length = window_length - self.num_sink_tokens = num_sink_tokens - self.cos_sin_rerotation_cache = {} - self._cos_cache = None - self._sin_cache = None - - @staticmethod - def _rotate_half(x): - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - def _apply_key_rotary_pos_emb( - self, key_states, cos, sin - ): - rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin) - return rotated_key_states - - def _get_rerotation_cos_sin( - self, key_states, cos, sin - ): - if key_states.shape[-2] not in self.cos_sin_rerotation_cache: - # Upcast to float32 temporarily for better accuracy - cos = cos.to(torch.float32) - sin = sin.to(torch.float32) - - original_cos = cos[self.num_sink_tokens + key_states.shape[-2]:] - shifted_cos = cos[self.num_sink_tokens:-key_states.shape[-2]] - original_sin = sin[self.num_sink_tokens + key_states.shape[-2]:] - shifted_sin = sin[self.num_sink_tokens:-key_states.shape[-2]] - rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin - rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin - - self.cos_sin_rerotation_cache[key_states.shape[-2]] = ( - rerotation_cos.to(key_states.dtype).unsqueeze(0), - rerotation_sin.to(key_states.dtype).unsqueeze(0), - ) - return self.cos_sin_rerotation_cache[key_states.shape[-2]] - - def get_seq_length(self, layer_idx=0): - """Returns the sequence length of the cached states. - - A layer index can be optionally passed. - """ - if len(self.key_cache) <= layer_idx: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_max_cache_shape(self): - """Returns the maximum sequence length of the cache object, in case of - SinkCache it is the window length.""" - return self.window_length - - def update( - self, - key_states, - value_states, - layer_idx, - cache_kwargs, - ): - - if self.use_org_kv: - return super().update(key_states, value_states, layer_idx, cache_kwargs) - else: - sin = cache_kwargs.get('sin') - cos = cache_kwargs.get('cos') - partial_rotation_size = cache_kwargs.get('partial_rotation_size') - using_rope = cos is not None and sin is not None - - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - if using_rope and layer_idx == 0: - - if cos.dim() == 2: - self._cos_cache = cos - self._sin_cache = sin - else: - if self._cos_cache is None: - self._cos_cache = cos[0, ...] - self._sin_cache = sin[0, ...] 
- elif self._cos_cache.shape[0] < self.window_length: - self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0) - self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0) - - # [bsz, num_heads, seq_len, head_dim] - if len(self.key_cache) <= layer_idx: - # Empty cache - self.key_cache.append(key_states) - self.value_cache.append(value_states) - - elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length: - # Growing cache - self.key_cache[layer_idx] = \ - torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = \ - torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - - else: - # Shifting cache - keys_to_keep = self.key_cache[layer_idx][ - :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2]: - ] - - if using_rope: - rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin( - key_states, - self._cos_cache[: self.window_length], - self._sin_cache[: self.window_length] - ) - if partial_rotation_size is not None: - keys_to_keep, keys_pass = ( - keys_to_keep[..., :partial_rotation_size], - keys_to_keep[..., partial_rotation_size:], - ) - keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, - rerotation_cos, - rerotation_sin) - if partial_rotation_size is not None: - keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1) - - # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens - sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens] - - dq_keys_to_keep = self._dequantize(self._quantize(keys_to_keep.contiguous(), - layer_idx, - is_key=True)) - dq_keys = self._dequantize(self._quantize(key_states.contiguous(), - layer_idx, - is_key=True)) - - self.key_cache[layer_idx] = torch.cat([sink_keys, dq_keys_to_keep, dq_keys], dim=-2) - - sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens] - values_to_keep = self.value_cache[layer_idx][ - :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2]: - ] - dq_values_to_keep = self._dequantize(self._quantize(values_to_keep.contiguous(), - layer_idx, - is_key=True)) - dq_values = self._dequantize(self._quantize(value_states.contiguous(), - layer_idx, - is_key=True)) - - self.value_cache[layer_idx] = torch.cat([sink_values, - dq_values_to_keep, - dq_values], dim=-2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] diff --git a/llmc/compression/quantization/module_utils.py b/llmc/compression/quantization/module_utils.py index 8b4f8b62..1710d288 100644 --- a/llmc/compression/quantization/module_utils.py +++ b/llmc/compression/quantization/module_utils.py @@ -21,401 +21,6 @@ from .utils import calculate_zeros_width -class LlmcMatmul(nn.Module): - def __init__(self, a1_qdq=None, a2_qdq=None): - super().__init__() - self.a1_qdq = a1_qdq - self.a2_qdq = a2_qdq - self.calib = True - - def forward(self, x1, x2): - if self.a1_qdq is not None and not self.calib: - x1 = self.a1_qdq(x1, self) - if self.a2_qdq is not None and not self.calib: - x2 = self.a2_qdq(x2, self) - out = torch.matmul(x1, x2) - return out - - def __repr__(self): - return f'LlmcMatmul(calib={self.calib})' - - -class LlmcSoftmax(nn.Module): - def __init__(self, a_qdq=None): - super().__init__() - self.a_qdq = a_qdq - self.calib = True - - def forward(self, x, dim=-1, dtype=None): - if self.a_qdq is not None and not self.calib: - x = self.a_qdq(x, self) - out = nn.functional.softmax(x, dim=dim, dtype=dtype) - return out - - def __repr__(self): - return f'LlmcSoftmax(calib={self.calib})' - - 
-class LlmcViTSelfAttention(nn.Module): - def __init__( - self, - query, - key, - value, - num_attention_heads, - attention_head_size, - all_head_size, - dropout, - matmul_a1_qdq, - matmul_a2_qdq, - softmax_a_qdq, - ): - super().__init__() - self.num_attention_heads = num_attention_heads - self.attention_head_size = attention_head_size - self.all_head_size = all_head_size - self.query = query - self.key = key - self.value = value - - self.dropout = dropout - - self.matmul_1 = LlmcMatmul(matmul_a1_qdq, matmul_a2_qdq) - self.matmul_2 = LlmcMatmul(matmul_a1_qdq, matmul_a2_qdq) - self.softmax = LlmcSoftmax(softmax_a_qdq) - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + ( - self.num_attention_heads, - self.attention_head_size, - ) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward(self, hidden_states, head_mask=None, output_attentions=False): - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) - - attention_scores = self.matmul_1(query_layer, key_layer.transpose(-1, -2)) - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - - attention_probs = self.softmax(attention_scores, dim=-1) - attention_probs = self.dropout(attention_probs) - - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = self.matmul_2(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = ( - (context_layer, attention_probs) if output_attentions else (context_layer,) - ) - - return outputs - - @classmethod - @torch.no_grad() - def new(cls, module, matmul_a1_qdq=None, matmul_a2_qdq=None, softmax_a_qdq=None): - query, key, value = module.query, module.key, module.value - num_attention_heads = module.num_attention_heads - attention_head_size = module.attention_head_size - all_head_size = module.all_head_size - dropout = module.dropout - new_module = cls( - query, - key, - value, - num_attention_heads, - attention_head_size, - all_head_size, - dropout, - matmul_a1_qdq, - matmul_a2_qdq, - softmax_a_qdq, - ) - return new_module - - def __repr__(self): - return ( - f'LlmcViTSelfAttention(\n' - f' (query): {self.query}\n' - f' (key): {self.key}\n' - f' (value): {self.value}\n' - f' (dropout): {self.dropout}\n' - f' (matmul_1): {self.matmul_1}\n' - f' (matmul_2): {self.matmul_2}\n' - f' (softmax): {self.softmax}\n' - f')' - ) - - -class LlmcDeepseekAttention(nn.Module): - def __init__( - self, - config, - layer_idx, - attention_dropout, - hidden_size, - num_heads, - max_position_embeddings, - rope_theta, - q_lora_rank, - qk_rope_head_dim, - kv_lora_rank, - v_head_dim, - qk_nope_head_dim, - q_head_dim, - is_causal, - q_proj, - q_a_proj, - q_a_layernorm, - q_b_proj, - kv_a_proj_with_mqa, - kv_a_layernorm, - kv_b_proj, - o_proj, - rotary_emb, - softmax_scale, - matmul_a1_qdq, - matmul_a2_qdq, - softmax_a_qdq, - ): - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.attention_dropout = attention_dropout - self.hidden_size = hidden_size - self.num_heads = num_heads - self.max_position_embeddings = max_position_embeddings - self.rope_theta = rope_theta - self.q_lora_rank = q_lora_rank - self.qk_rope_head_dim = qk_rope_head_dim 
- self.kv_lora_rank = kv_lora_rank - self.v_head_dim = v_head_dim - self.qk_nope_head_dim = qk_nope_head_dim - self.q_head_dim = q_head_dim - self.is_causal = is_causal - self.q_proj = q_proj - self.q_a_proj = q_a_proj - self.q_a_layernorm = q_a_layernorm - self.q_b_proj = q_b_proj - self.kv_a_proj_with_mqa = kv_a_proj_with_mqa - self.kv_a_layernorm = kv_a_layernorm - self.kv_b_proj = kv_b_proj - self.o_proj = o_proj - self.rotary_emb = rotary_emb - self.softmax_scale = softmax_scale - self.matmul_1 = LlmcMatmul(matmul_a1_qdq, matmul_a2_qdq) - self.matmul_2 = LlmcMatmul(matmul_a1_qdq, matmul_a2_qdq) - self.softmax = LlmcSoftmax(softmax_a_qdq) - - def _shape(self, tensor, seq_len, bsz): - return ( - tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim) - .transpose(1, 2) - .contiguous() - ) - - def rotate_half(self, x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids, unsqueeze_dim=1): - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - - b, h, s, d = q.shape - q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) - - b, h, s, d = k.shape - k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) - - q_embed = (q * cos) + (self.rotate_half(q) * sin) - k_embed = (k * cos) + (self.rotate_half(k) * sin) - return q_embed, k_embed - - @classmethod - @torch.no_grad() - def new(cls, module, matmul_a1_qdq=None, matmul_a2_qdq=None, softmax_a_qdq=None): - - config = module.config - layer_idx = module.layer_idx - - attention_dropout = module.config.attention_dropout - hidden_size = module.config.hidden_size - num_heads = module.config.num_attention_heads - - max_position_embeddings = module.config.max_position_embeddings - rope_theta = module.config.rope_theta - q_lora_rank = module.config.q_lora_rank - qk_rope_head_dim = module.config.qk_rope_head_dim - kv_lora_rank = module.config.kv_lora_rank - v_head_dim = module.config.v_head_dim - qk_nope_head_dim = module.config.qk_nope_head_dim - q_head_dim = module.q_head_dim - is_causal = module.is_causal - - if q_lora_rank is None: - q_proj = module.q_proj - q_a_proj = None - q_a_layernorm = None - q_b_proj = None - else: - q_proj = None - q_a_proj = module.q_a_proj - q_a_layernorm = module.q_a_layernorm - q_b_proj = module.q_b_proj - - kv_a_proj_with_mqa = module.kv_a_proj_with_mqa - kv_a_layernorm = module.kv_a_layernorm - kv_b_proj = module.kv_b_proj - - o_proj = module.o_proj - rotary_emb = module.rotary_emb - - softmax_scale = module.softmax_scale - - new_module = cls( - config=config, - layer_idx=layer_idx, - attention_dropout=attention_dropout, - hidden_size=hidden_size, - num_heads=num_heads, - max_position_embeddings=max_position_embeddings, - rope_theta=rope_theta, - q_lora_rank=q_lora_rank, - qk_rope_head_dim=qk_rope_head_dim, - kv_lora_rank=kv_lora_rank, - v_head_dim=v_head_dim, - qk_nope_head_dim=qk_nope_head_dim, - q_head_dim=q_head_dim, - is_causal=is_causal, - q_proj=q_proj, - q_a_proj=q_a_proj, - q_a_layernorm=q_a_layernorm, - q_b_proj=q_b_proj, - kv_a_proj_with_mqa=kv_a_proj_with_mqa, - kv_a_layernorm=kv_a_layernorm, - kv_b_proj=kv_b_proj, - o_proj=o_proj, - rotary_emb=rotary_emb, - softmax_scale=softmax_scale, - matmul_a1_qdq=matmul_a1_qdq, - matmul_a2_qdq=matmul_a2_qdq, - softmax_a_qdq=softmax_a_qdq, - ) - - return new_module - - def forward( - self, - hidden_states, - attention_mask, - 
position_ids, - past_key_value, - output_attentions, - use_cache, - **kwargs, - ): - bsz, q_len, _ = hidden_states.size() - if self.q_lora_rank is None: - q = self.q_proj(hidden_states) - else: - q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) - - q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) - q_nope, q_pe = torch.split( - q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 - ) - - compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv, k_pe = torch.split( - compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 - ) - k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) - kv = ( - self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) - .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - .transpose(1, 2) - ) - - k_nope, value_states = torch.split( - kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 - ) - kv_seq_len = value_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - q_pe, k_pe = self.apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - - query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - query_states[:, :, :, : self.qk_nope_head_dim] = q_nope - query_states[:, :, :, self.qk_nope_head_dim:] = q_pe - - key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - key_states[:, :, :, : self.qk_nope_head_dim] = k_nope - key_states[:, :, :, self.qk_nope_head_dim:] = k_pe - if past_key_value is not None: - cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - - attn_weights = ( - self.matmul_1(query_states, key_states.transpose(2, 3)) * self.softmax_scale - ) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)},' - f'but is {attn_weights.size()}' - ) - assert attention_mask is not None - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},' - f'but is {attention_mask.size()}' - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = self.softmax(attn_weights, dim=-1, dtype=torch.float32).to( - query_states.dtype - ) - attn_weights = nn.functional.dropout( - attn_weights, p=self.attention_dropout, training=self.training - ) - attn_output = self.matmul_2(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): - raise ValueError( - f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)},' - f' but is {attn_output.size()}' - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - class LlmcActFn(nn.Module): def __init__(self, module, a_qdq) -> None: super().__init__() @@ -1331,8 +936,6 @@ def __repr__(self): LightllmRealQuantLinear, ] -_LLMC_ATTN_MAP_ = {'Vit': LlmcViTSelfAttention, 'DeepseekV2': LlmcDeepseekAttention} - _REALQUANT_LINEAR_MAP_ = { 'vllm_quant': 
VllmRealQuantLinear, 'lightllm_quant': LightllmRealQuantLinear, diff --git a/llmc/compression/quantization/quant.py b/llmc/compression/quantization/quant.py index 472f1579..3d2192fe 100644 --- a/llmc/compression/quantization/quant.py +++ b/llmc/compression/quantization/quant.py @@ -1,6 +1,5 @@ import torch from loguru import logger -from torch import nn class BaseQuantizer(object): @@ -38,7 +37,7 @@ def __init__(self, bit, symmetric, granularity, **kwargs): # hist config self.bins = self.kwargs.get('bins', 2048) self.hist_threshold = self.kwargs.get('hist_threshold', 1) - self.dst_nbins = 2**bit + self.dst_nbins = 2**bit if isinstance(bit, int) else None self.upsample_rate = ( 16 # used to reduce quantization errors when upscaling histogram ) @@ -1206,3 +1205,174 @@ def __repr__(self): f'granularity={self.granularity},' f'kwargs={self.kwargs}, qmin={self.qmin}, qmax={self.qmax})' ) + + +class Weight48IntegerQuantizer(BaseQuantizer): + # flake8: noqa + def __init__(self, bit, bit4, bit8, **kwargs): + super().__init__(bit, None, None, **kwargs) + self.quant_type = 'int-quant-w48' + assert self.bit == 48, 'Only support 48-bit quantization' + self.bit_settings = {} + self.bit_settings[4] = bit4 + self.bit_settings[8] = bit8 + for bit in [4, 8]: + if 'int_range' in self.bit_settings[bit]: + self.bit_settings[bit]['qmin'] = self.bit_settings[bit]['int_range'][0] + self.bit_settings[bit]['qmax'] = self.bit_settings[bit]['int_range'][1] + else: + if self.bit_settings[bit]['symmetric']: + self.bit_settings[bit]['qmin'] = -(2 ** (bit - 1)) + self.bit_settings[bit]['qmax'] = 2 ** (bit - 1) - 1 + else: + self.bit_settings[bit]['qmin'] = 0 + self.bit_settings[bit]['qmax'] = 2 ** bit - 1 + self.bit_settings[bit]['qmin'] = torch.tensor(self.bit_settings[bit]['qmin']) + self.bit_settings[bit]['qmax'] = torch.tensor(self.bit_settings[bit]['qmax']) + if 'scales_bit' in self.bit_settings[bit]: + if self.bit_settings[bit]['scales_symmetric']: + self.bit_settings[bit]['scales_qmin'] = -(2 ** (self.bit_settings[bit]['scales_bit'] - 1)) + self.bit_settings[bit]['scales_qmax'] = 2 ** (self.bit_settings[bit]['scales_bit'] - 1) - 1 + else: + self.bit_settings[bit]['scales_qmin'] = 0 + self.bit_settings[bit]['scales_qmax'] = 2 ** self.bit_settings[bit]['scales_bit'] - 1 + else: + self.bit_settings[bit]['scales_qmin'] = -torch.inf + self.bit_settings[bit]['scales_qmax'] = torch.inf + if 'zeros_bit' in self.bit_settings[bit]: + if self.bit_settings[bit]['zeros_symmetric']: + self.bit_settings[bit]['zeros_qmin'] = -(2 ** (self.bit_settings[bit]['scales_bit'] - 1)) + self.bit_settings[bit]['zeros_qmax'] = 2 ** (self.bit_settings[bit]['scales_bit'] - 1) - 1 + else: + self.bit_settings[bit]['zeros_qmin'] = 0 + self.bit_settings[bit]['zeros_qmax'] = 2 ** self.bit_settings[bit]['scales_bit'] - 1 + else: + self.bit_settings[bit]['zeros_qmin'] = self.bit_settings[bit]['qmin'] + self.bit_settings[bit]['zeros_qmax'] = self.bit_settings[bit]['qmax'] + + def reshape_tensor(self, tensor, bit): + granularity = self.bit_settings[bit].get('granularity') + if granularity == 'per_group': + group_size = self.bit_settings[bit].get('group_size') + if tensor.shape[-1] % group_size == 0: + t = tensor.reshape(-1, group_size) + else: + raise ValueError( + f'Dimension {tensor.shape[-1]} ' + f'not divisible by group size {group_size}' + ) + else: + t = tensor + return t + + def get_qparams(self, tensor_range, device, bit): + min_val, max_val = tensor_range[0], tensor_range[1] + qmin = self.bit_settings[bit]['qmin'].to(device) + qmax = 
self.bit_settings[bit]['qmax'].to(device) + sym = self.bit_settings[bit]['symmetric'] + if sym: + abs_max = torch.max(max_val.abs(), min_val.abs()) + abs_max = abs_max.clamp(min=1e-5) + scales = abs_max / qmax + zeros = torch.tensor(0.0) + else: + scales = (max_val - min_val).clamp(min=1e-5) / (qmax - qmin) + zeros = (qmin - torch.round(min_val / scales)) + scales = scales.clamp(self.bit_settings[bit]['scales_qmin'], self.bit_settings[bit]['scales_qmax']) + zeros = zeros.clamp(self.bit_settings[bit]['zeros_qmin'], self.bit_settings[bit]['zeros_qmax']) + return scales, zeros, qmax, qmin + + def quant(self, tensor, scales, zeros, qmax, qmin): + tensor = torch.clamp(self.round_func(tensor / scales) + zeros, qmin, qmax) + return tensor + + def dequant(self, tensor, scales, zeros): + tensor = (tensor - zeros) * scales + return tensor + + def quant_dequant(self, tensor, scales, zeros, qmax, qmin): + tensor = self.quant(tensor, scales, zeros, qmax, qmin) + tensor = self.dequant(tensor, scales, zeros) + return tensor + + def fake_quant_weight_dynamic(self, weight, args={}): + # step 1: quantize to 8-bit + org_shape16 = weight.shape + org_dtype16 = weight.dtype + weight = self.reshape_tensor(weight, bit=8) + weight_range = self.get_tensor_range(weight) + scales816, zeros816, qmax816, qmin816 = self.get_qparams(weight_range, weight.device, bit=8) + weight = self.quant(weight, scales816, zeros816, qmax816, qmin816) + + # step 2: quantize to 4-bit + org_shape8 = weight.shape + org_dtype8 = weight.dtype + weight = self.reshape_tensor(weight, bit=4) + weight_range = self.get_tensor_range(weight) + scales48, zeros48, qmax48, qmin48 = self.get_qparams(weight_range, weight.device, bit=4) + weight = self.quant(weight, scales48, zeros48, qmax48, qmin48) + + # step 3: dequantize to 8-bit + weight = self.dequant(weight, scales48, zeros48) + weight = self.restore_tensor(weight, org_shape8).to(org_dtype8) + + # step 4: dequantize to 16-bit + weight = self.dequant(weight, scales816, zeros816) + weight = self.restore_tensor(weight, org_shape16).to(org_dtype16) + + return weight + + +if __name__ == '__main__': + def test_Weight48IntegerQuantizer(): + torch.manual_seed(0) + torch.cuda.manual_seed(0) + + weight = torch.randn(4096, 8192).cuda() + print(weight) + + ''' + weight: + bit: 48 + bit4: + symmetric: False + granularity: per_group + group_size: 128 + scales_bit: 8 + scales_symmetric: True + zeros_bit: 8 + zeros_symmetric: True + bit8: + symmetric: True + granularity: per_channel + int_range: [-120, 120] + ''' + cfg = { + 'bit': 48, + 'bit4': { + 'symmetric': False, + 'granularity': 'per_group', + 'group_size': 128, + 'scales_bit': 8, + 'scales_symmetric': True, + 'zeros_bit': 8, + 'zeros_symmetric': True + }, + 'bit8': { + 'symmetric': True, + 'granularity': 'per_channel', + 'int_range': [-120, 120] + } + } + + int_quant = Weight48IntegerQuantizer(**cfg) + + int_weight = int_quant.fake_quant_weight_dynamic(weight) + + print(int_weight) + from torch import nn + cosine_sim = nn.CosineSimilarity() + cos = cosine_sim(weight.float().view(1, -1), int_weight.float().view(1, -1)) + print(cos) + + test_Weight48IntegerQuantizer() diff --git a/llmc/compression/sparsification/__init__.py b/llmc/compression/sparsification/__init__.py index b09a1347..acff9921 100644 --- a/llmc/compression/sparsification/__init__.py +++ b/llmc/compression/sparsification/__init__.py @@ -1,5 +1,6 @@ from .base_blockwise_sparsification import BaseBlockwiseSparsification +from .dense import Dense +from .kvsparse import ShadowKVCache, SinkKVCache 
from .magnitude import Magnitude from .shortgpt import ShortGPT -from .sparse import Sparser from .wanda import Wanda diff --git a/llmc/compression/sparsification/attn_utils.py b/llmc/compression/sparsification/attn_utils.py new file mode 100644 index 00000000..a5102d71 --- /dev/null +++ b/llmc/compression/sparsification/attn_utils.py @@ -0,0 +1,144 @@ +import torch +import torch.nn as nn +from loguru import logger +from transformers.models.llama.modeling_llama import (apply_rotary_pos_emb, + repeat_kv) + + +def eager_attention_forward( + module, + query, + key, + value, + attention_mask, + scaling, + dropout, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class ShadowKVAttention(nn.Module): + def __init__(self, module): + super().__init__() + self.config = module.config + self.layer_idx = module.layer_idx + self.head_dim = module.head_dim + self.num_key_value_groups = module.num_key_value_groups + self.scaling = self.head_dim**-0.5 + self.attention_dropout = module.attention_dropout + self.is_causal = True + + self.q_proj = module.q_proj + self.k_proj = module.k_proj + self.v_proj = module.v_proj + self.o_proj = module.o_proj + + def forward( + self, + hidden_states, + position_embeddings, + position_ids, + attention_mask, + past_key_value, + output_attentions, + use_cache, + cache_position, + retrieval_position_ids=None, + cos_sin_cache=None, + **kwargs, + ): + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + if past_key_value is not None and past_key_value.prefill: + past_key_value.get_svd(key_states, layer_idx=self.layer_idx) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + key_states, value_states = \ + past_key_value.update(key_states, + value_states, + self.layer_idx, + retrieval_position_ids, + cos_sin_cache) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) + # bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. 
+ if query_states.device.type == 'cuda' and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels + # via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. + # An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + @classmethod + @torch.no_grad() + def new(cls, module): + new_module = cls(module) + return new_module + + def __repr__(self): + return ( + f'ShadowKVAttention(\n' + f' (q_proj): {self.q_proj}\n' + f' (k_proj): {self.k_proj}\n' + f' (v_proj): {self.v_proj}\n' + f' (o_proj): {self.o_proj}\n' + f' (kvcache): {self.kvcache}\n' + f')' + ) + + +_LLMC_ATTN_MAP_ = {'ShadowKV': {'Llama': ShadowKVAttention}} diff --git a/llmc/compression/sparsification/base_blockwise_sparsification.py b/llmc/compression/sparsification/base_blockwise_sparsification.py index e62f5f27..a068b920 100644 --- a/llmc/compression/sparsification/base_blockwise_sparsification.py +++ b/llmc/compression/sparsification/base_blockwise_sparsification.py @@ -6,9 +6,10 @@ from loguru import logger from llmc.utils import copy_files +from llmc.utils.registry_factory import KV_REGISTRY from ..blockwise_optimization import BlockwiseOpt -from .sparse import Sparser +from .attn_utils import _LLMC_ATTN_MAP_ class BaseBlockwiseSparsification(BlockwiseOpt): @@ -28,7 +29,67 @@ def set_sparsity_config(self): self.sparsity_out = False logger.info(f'use sparsity_out {self.sparsity_out}') - self.sparser = Sparser(self.sparsity_config['weight']) + # set kv cache sparse config + if 'kvcache' in self.sparsity_config: + self.sparse_kvcache = True + self.set_kv_sparse_config() + else: + self.sparse_kvcache = False + + if 'weight' in self.sparsity_config: + if 'sparsity' in self.sparsity_config['weight']: + self.sparsity = self.sparsity_config['weight']['sparsity'] + self.W_mask = None + elif 'n_prune_layers' in self.sparsity_config: + self.n_prune_layers = self.sparsity_config['weight']['n_prune_layers'] + + def set_kv_sparse_config(self): + kv_sparse_config = {} + if self.sparsity_config['kvcache']['method'] == 'ShadowKV': + assert self.config['model']['type'] in ['Llama'] + assert self.config['eval'].get('type', None) != 'decode_ppl' + inv_freq = \ + self.model.model.model.layers[0].self_attn.rotary_emb.inv_freq.cuda() + cos_cache, sin_cache = self.set_cos_sin_cache(inv_freq) + self.cos_sin_cache = (cos_cache, sin_cache) + kv_sparse_config['config'] = self.model.model_config + elif self.sparsity_config['kvcache']['method'] == 'SinkKV': + kv_sparse_config['num_hidden_layers'] = self.model.model_config.num_hidden_layers + kv_sparse_config['window_length'] = self.sparsity_config['kvcache']['window_length'] + kv_sparse_config['num_sink_tokens'] = self.sparsity_config['kvcache']['num_sink_tokens'] + self.kv_module = 
KV_REGISTRY[self.sparsity_config['kvcache']['method']](**kv_sparse_config) + self.replace_attn = self.sparsity_config['kvcache'].get('replace_attn', False) + self.model.kvcache_buffer.append(self.kv_module) + + def set_cos_sin_cache(self, inv_freq): + max_length = 64 * 1024 + t = torch.arange(max_length + 1024, device=torch.device('cuda'), dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return emb.cos().to(torch.bfloat16), emb.sin().to(torch.bfloat16) + + @torch.no_grad() + def register_kv_cache(self, block): + attn_layers_dict = self.model.get_attn_in_block(block) + attn_layer = attn_layers_dict[list(attn_layers_dict.keys())[0]] + setattr(attn_layer, 'kvcache', self.kv_module) + attn_layer.register_forward_pre_hook( + self.kv_cache_input_hook(attn_layer), with_kwargs=True + ) + + def replace_attention(self, block): + attn_layers_dict = self.model.get_attn_in_block(block) + layers_dict = {'layers': attn_layers_dict} + kv_method = self.sparsity_config['kvcache']['method'] + model_type = self.config['model']['type'] + attn_module = _LLMC_ATTN_MAP_[kv_method][model_type] + self.model.replace_module_subset( + attn_module, + block, + layers_dict, + self.block_idx, + {} + ) def block_forward(self, block, input_data=None): output = [] @@ -49,39 +110,48 @@ def block_forward(self, block, input_data=None): return output def block_opt(self, block): + if self.sparse_kvcache: + if self.replace_attn: + self.replace_attention(block) + self.register_kv_cache(block) block = block.cuda() - named_linears = self.model.get_block_linears(block) - logger.info(f'named_linears: {named_linears}') - input_feat = defaultdict(list) - handles = [] - self.block_init(block) - - for name in named_linears: - handles.append( - named_linears[name].register_forward_hook( - functools.partial( - self.cache_input_hook, name=name, feat_dict=input_feat + + if not self.data_free: + named_linears = self.model.get_block_linears(block) + logger.info(f'named_linears: {named_linears}') + input_feat = defaultdict(list) + handles = [] + self.block_init(block) + + for name in named_linears: + handles.append( + named_linears[name].register_forward_hook( + functools.partial( + self.cache_input_hook, name=name, feat_dict=input_feat + ) ) ) - ) - if not self.sparsity_out: - self.input['data'] = self.block_forward(block) - else: - self.block_forward(block) - for h in handles: - h.remove() - torch.cuda.empty_cache() + if not self.sparsity_out: + self.input['data'] = self.block_forward(block) + else: + self.block_forward(block) + for h in handles: + h.remove() + torch.cuda.empty_cache() + + self.block_transform(block, input_feat, self.input['kwargs']) - self.block_transform(block, input_feat, self.input['kwargs']) + if self.sparsity_out: + self.input['data'] = self.block_forward(block) - if self.sparsity_out: - self.input['data'] = self.block_forward(block) + block = block.cpu() + del input_feat + gc.collect() + torch.cuda.empty_cache() - block = block.cpu() - del input_feat - gc.collect() - torch.cuda.empty_cache() + else: + self.block_transform(block) def block_transform(self, block, input_feat, block_kwargs): logger.info(f'Start transform the {self.block_idx+1}-th block') @@ -113,13 +183,11 @@ def filter_subset(self, subset): def deploy(self, deploy_format): logger.info('-- deploy_sparsity_model start --') logger.info(f'sparsity_config : {self.sparsity_config}') - logger.info('-- deploy_sparsity_model done --') @torch.no_grad() def copy_tokenizer(self, path): - for substring in 
self.config.save.get('tokenizer_file_substring', ['token']): - copy_files(self.config.model.path, path, substring) + self.model.tokenizer.save_pretrained(path) logger.info('copy tokenizer done --') @torch.no_grad() diff --git a/llmc/compression/sparsification/dense.py b/llmc/compression/sparsification/dense.py new file mode 100644 index 00000000..dad9612b --- /dev/null +++ b/llmc/compression/sparsification/dense.py @@ -0,0 +1,16 @@ +from loguru import logger + +from llmc.utils.registry_factory import ALGO_REGISTRY + +from .base_blockwise_sparsification import BaseBlockwiseSparsification + + +@ALGO_REGISTRY +class Dense(BaseBlockwiseSparsification): + def __init__(self, model, sparsity_config, input, padding_mask, config): + super().__init__(model, sparsity_config, input, padding_mask, config) + + def block_transform(self, block): + logger.info(f'Start transform the {self.block_idx+1}-th block') + logger.info(block) + logger.info(f'End transform the {self.block_idx+1}-th block') diff --git a/llmc/compression/sparsification/kvsparse.py b/llmc/compression/sparsification/kvsparse.py new file mode 100644 index 00000000..b251ae47 --- /dev/null +++ b/llmc/compression/sparsification/kvsparse.py @@ -0,0 +1,653 @@ + +import math + +import torch +import torch.nn as nn +from loguru import logger +from transformers import DynamicCache + +from llmc.utils.registry_factory import KV_REGISTRY + + +def apply_rotary_pos_emb_single(q, cos, sin, position_ids, unsqueeze_dim=1): + # if position_ids shape is (batch_size, num_heads, seq_len), + # then reshape it to (batch_size*num_heads, seq_len) + if len(position_ids.shape) == 3: + position_ids = position_ids.view(-1, position_ids.size(-1)) + cos = cos[position_ids] + sin = sin[position_ids] + q_embed = (q * cos) + (rotate_half(q) * sin) + + else: + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + return q_embed + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +@KV_REGISTRY.register('ShadowKV') +class ShadowKVCache(DynamicCache): + """ShadowKV, only for accuracy measurement and understanding, not for + efficiency, please refer to ShadowKV_CPU for the efficient + implementation.""" + + def __init__( + self, + config, + batch_size=1, + max_length=32 * 1024, + device='cuda:0', + dtype=torch.bfloat16, + sparse_budget=1024, + chunk_size=8, + rank=160, + outlier_chunk=48 + ): + + super().__init__() + self.config = config + self.batch_size = batch_size + self.max_length = max_length + self.device = device + self.dtype = dtype + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.head_dim = config.hidden_size // config.num_attention_heads + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + + self.sparse_budget = int(sparse_budget) + self.chunk_size = chunk_size + self.rank = rank + self.local_chunk = 4 + self.outlier_chunk = outlier_chunk + + assert self.batch_size == 1, 'ShadowKV class only supports batch_size=1' + + self.selected_chunk_idx = torch.zeros( + config.num_hidden_layers, + batch_size, + config.num_key_value_heads, + self.sparse_budget // self.chunk_size, + device=self.device, + dtype=torch.long, + ) + + self.v_cache_cpu = torch.zeros( + config.num_hidden_layers, + batch_size, + config.num_key_value_heads, + self.max_length, + self.config.hidden_size // 
self.config.num_attention_heads, + device=self.device, + dtype=self.dtype, + ) + + self.k_cache_buffer = torch.zeros( + config.num_hidden_layers, + batch_size, + config.num_key_value_heads, + self.sparse_budget + 4096, + self.config.hidden_size // self.config.num_attention_heads, + device=self.device, + dtype=self.dtype, + ) + + self.v_cache_buffer = torch.zeros( + config.num_hidden_layers, + batch_size, + config.num_key_value_heads, + self.sparse_budget + 4096, + self.config.hidden_size // self.config.num_attention_heads, + device=self.device, + dtype=self.dtype, + ) + + self.num_layers = config.num_hidden_layers + self.kv_offset = 0 + self.prefill = 0 + self.gen_offset = 0 + + self.k_landmark = None + self.k_landmark_idx = None + self.U = None + self.SV = None + + self.copy_stream = torch.cuda.Stream() + self.prefill = True + self.prefill_layers = 0 + + def update( + self, + key_states, + value_states, + layer_idx, + retrieval_position_ids, + cos_sin_cache, + ): + # Update the cache + if self.prefill_layers == layer_idx: + # Prefill + self.prefill_kv_cache(value_states, layer_idx, key_states) + self.prefill_layers += 1 + if layer_idx == self.num_layers - 1: + self.prefill = False + self.prefill_layers = -1 + return key_states, value_states + else: + # Decode + self.update_kv_cache(key_states, value_states, layer_idx) + value_states = self.get_value_cache(layer_idx, retrieval_position_ids) + key_states = self.get_key_cache( + layer_idx, retrieval_position_ids, cos_sin_cache + ) + + return key_states, value_states + + def _reset_states(self): + self.k_cache_buffer.zero_() + self.v_cache_buffer.zero_() + self.selected_chunk_idx.zero_() + self.k_landmark = None + self.k_landmark_idx = None + self.U = None + self.SV = None + + self.kv_offset = 0 + self.prefill = 0 + self.gen_offset = 0 + self.prefill_local = 0 + + self.key_cache = [] + self.value_cache = [] + self._seen_tokens = 0 + self.prefill = True + self.prefill_layers = 0 + + def get_seq_length(self, layer_idx=0): + return self.kv_offset + + def get_svd(self, new_k_cache, layer_idx): + # [bsz, 8, prefill, 128] OR [bsz, prefill, 1024] + if new_k_cache.shape[1] <= 32: + # [bsz, 8, prefill, 128] --> [bsz, prefill, 1024] + k_cache = new_k_cache.transpose(1, 2).reshape( + self.batch_size, -1, self.num_key_value_heads * self.head_dim + ) + else: + # [bsz, prefill, 1024] + k_cache = new_k_cache + + if layer_idx == 0: + # init U, SV + self.U = torch.zeros( + self.num_layers, + self.batch_size, + k_cache.shape[1], + self.rank, + device=self.device, + dtype=self.dtype, + ) + self.SV = torch.zeros( + self.num_layers, + self.batch_size, + self.num_key_value_heads, + self.rank, + self.head_dim, + device=self.device, + dtype=self.dtype, + ) + + u, s, v = torch.svd(k_cache.float()) + v = v.transpose(1, 2) + # [bsz, 128k, 1024] --> [bsz, 128k, 160] [bsz, 160, 1024] (bsz, 8, 160, 128) + self.U[layer_idx].copy_(u[:, :, : self.rank].to(self.dtype)) # [bsz, 128k, 160] + self.SV[layer_idx].copy_( + torch.matmul(torch.diag_embed(s[:, : self.rank]), v[:, : self.rank]) + .to(self.dtype) + .view(self.batch_size, -1, self.num_key_value_heads, self.head_dim) + .transpose(1, 2) + ) # [bsz, 8, 160, 128] + + def register_k_landmark(self, k_landmark, k_landmark_idx, layer_idx): + num_landmarks = k_landmark.shape[-2] + if layer_idx == 0: + # init k_landmark, k_landmark_idx + self.k_landmark = torch.zeros( + self.num_layers, + self.batch_size, + self.num_key_value_heads, + num_landmarks, + self.head_dim, + device=self.device, + dtype=self.dtype, + ) + 
self.k_landmark_idx = torch.zeros( + self.num_layers, + self.batch_size, + self.num_key_value_heads, + num_landmarks, + device=self.device, + dtype=torch.long, + ) + + self.k_landmark[layer_idx].copy_(k_landmark.contiguous()) + self.k_landmark_idx[layer_idx].copy_(k_landmark_idx.contiguous()) + + def prefill_kv_cache( + self, + new_v_cache, + layer_idx, + key_states_roped, + ): + + incoming = new_v_cache.shape[-2] # [bsz, num_kv_heads, incoming, head_dim] + self.prefill = incoming + self.v_cache_cpu[layer_idx][:, :, :incoming] = new_v_cache.clone() + + # [x0, x1, ...., self.chunks*chunk_size, local_chunk, rest] + self.chunks = incoming // self.chunk_size - self.local_chunk + self.select_sets = self.sparse_budget // self.chunk_size + + assert ( + self.select_sets * self.chunk_size == self.sparse_budget + ), f'({self.select_sets}) * {self.chunk_size} != {self.sparse_budget}' + + # store Post-RoPE k cache to the cache + self.prefill_local = ( + incoming - self.chunks * self.chunk_size + ) # local chunks + align to chunk_size + self.k_cache_buffer[layer_idx][:, :, : self.prefill_local].copy_( + key_states_roped[:, :, -self.prefill_local:] + ) + self.v_cache_buffer[layer_idx][:, :, : self.prefill_local].copy_( + new_v_cache[:, :, -self.prefill_local:] + ) + + key_states_roped_ctx = key_states_roped[ + :, :, : self.chunks * self.chunk_size + ].view( + self.batch_size, + self.num_key_value_heads, + self.chunks, + self.chunk_size, + self.head_dim, + ) + landmark_candidates = key_states_roped_ctx.mean( + dim=-2 + ) # [bsz, kv_heads, chunks, head_dim] + + # compute the cos similarity between it and the original key cache + cos_sim = torch.nn.functional.cosine_similarity( + landmark_candidates.unsqueeze(3).expand(-1, -1, -1, self.chunk_size, -1), + key_states_roped_ctx, + dim=-1, + ) # [bsz, kv_heads, chunks, chunk_size] + + # get the outlier_chunk idx for each head # [bsz, kv_heads, outlier_chunk] + outlier_chunk_idx = ( + cos_sim.min(dim=-1).values.topk(self.outlier_chunk, largest=False).indices + ) + + outlier_chunk_k_cache = key_states_roped_ctx.gather( + dim=2, + index=outlier_chunk_idx.unsqueeze(-1) + .unsqueeze(-1) + .expand(-1, -1, -1, self.chunk_size, self.head_dim), + ).view( + self.batch_size, + self.num_key_value_heads, + self.outlier_chunk * self.chunk_size, + self.head_dim, + ) + + outlier_chunk_v_cache = ( + new_v_cache[:, :, : self.chunks * self.chunk_size] + .view( + self.batch_size, + self.num_key_value_heads, + self.chunks, + self.chunk_size, + self.head_dim, + ) + .gather( + dim=2, + index=outlier_chunk_idx.unsqueeze(-1) + .unsqueeze(-1) + .expand(-1, -1, -1, self.chunk_size, self.head_dim), + ) + .view( + self.batch_size, + self.num_key_value_heads, + self.outlier_chunk * self.chunk_size, + self.head_dim, + ) + ) + + self.sparse_start = self.prefill_local + self.outlier_chunk * self.chunk_size + self.sparse_end = ( + self.prefill_local + + self.outlier_chunk * self.chunk_size + + self.sparse_budget + ) + + # store outlier_chunk to the cache + self.k_cache_buffer[layer_idx][ + :, :, self.prefill_local: self.sparse_start + ].copy_(outlier_chunk_k_cache) + self.v_cache_buffer[layer_idx][ + :, :, self.prefill_local: self.sparse_start + ].copy_(outlier_chunk_v_cache) + + # filter landmark_candidates using outlier_chunk and register the rest to k_landmark + # [bsz, kv_heads, chunks, head_dim] --> [bsz, kv_heads, chunks - outlier_chunk, head_dim] + # get rest_idx: [bsz, kv_heads, chunks] --filter--> [bsz, kv_heads, chunks - outlier_chunk] + all_idx = ( + torch.arange(self.chunks, 
device=key_states_roped.device) + .unsqueeze(0) + .unsqueeze(0) + .expand(self.batch_size, self.num_key_value_heads, -1) + ) # [bsz, kv_heads, chunks] + mask = torch.ones_like(all_idx, dtype=torch.bool) + mask.scatter_(dim=-1, index=outlier_chunk_idx, value=False) + rest_idx = all_idx.masked_select(mask).view( + self.batch_size, self.num_key_value_heads, -1 + ) + + # register rest_idxed landmarks to k_landmark + self.register_k_landmark( + landmark_candidates.gather( + dim=2, index=rest_idx.unsqueeze(-1).expand(-1, -1, -1, self.head_dim) + ).view(self.batch_size, self.num_key_value_heads, -1, self.head_dim), + rest_idx, + layer_idx, + ) + + if layer_idx == self.num_layers - 1: + assert self.sparse_budget < incoming + self.kv_offset += incoming + + def get_retrieval_position_ids(self, layer_idx, query_states): + # self.k_landmark[layer_idx][:, :, :self.chunks] is [bsz, 8, chunks, head_dim] + # chunk_attn: [bsz, 32, window_size, chunks] + self.incoming_q_len = query_states.shape[-2] # 1 + # [bsz, 8, 4, q_len, 128] * [bsz, 8, 128, chunks] --> [bsz, 8, 4, q_len, chunks] + chunk_attn = torch.einsum( + 'bhgqd,bhdc->bhgqc', + query_states.view( + -1, + self.num_key_value_heads, + self.num_key_value_groups, + self.incoming_q_len, + self.head_dim, + ), + self.k_landmark[layer_idx].transpose(2, 3), + ).squeeze(2) / math.sqrt(128) + chunk_attn = nn.functional.softmax(chunk_attn, dim=-1, dtype=torch.float32).to( + self.dtype + ) # [bsz, 8, 4, q_len, chunks] + chunk_attn = chunk_attn.sum(dim=-2) # [bsz, 8, 4, chunks] + if self.num_key_value_groups > 1: + chunk_attn, _ = torch.max(chunk_attn, dim=-2) # [bsz, 8, chunks] + + merged_results = torch.topk( + chunk_attn, k=self.select_sets, dim=-1 + ).indices # [bsz, 8, select_sets(256)] + + # use merged_results to gather the position_ids: + # [bsz, 8, select_sets] --> [bsz, 8, select_sets] + selected_chunks = self.k_landmark_idx[layer_idx].gather( + dim=-1, index=merged_results + ) # [bsz, 8, select_sets] + + # this is chunk idx, which can be used to offload value cache and decide if the cache hits + self.selected_chunk_idx[layer_idx].copy_(selected_chunks, non_blocking=True) + + position_ids = ( + selected_chunks.unsqueeze(-1) * self.chunk_size + + torch.arange(self.chunk_size, device=chunk_attn.device) + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0) + ).view( + self.batch_size, self.num_key_value_heads, -1 + ) # [bsz, 8, select_sets * chunk_size] + + return position_ids + + def get_value_cache(self, layer_idx, retrieval_position_ids): + # gather value cache + value_ = self.v_cache_cpu[layer_idx].gather( + dim=-2, + index=retrieval_position_ids.unsqueeze(-1).expand( + -1, -1, -1, self.head_dim + ), + ) + self.v_cache_buffer[layer_idx][:, :, self.sparse_start: self.sparse_end].copy_( + value_, non_blocking=True + ) + gen_offset = ( + self.gen_offset + if layer_idx == self.num_layers - 1 + else self.gen_offset + self.incoming_q_len + ) + + return self.v_cache_buffer[layer_idx][:, :, : self.sparse_end + gen_offset] + + def get_key_cache(self, layer_idx, retrieval_position_ids, cos_sin_cache): + # gather key cache and rope them + u = self.U[layer_idx] # [bsz, 128k, rank] + sv = self.SV[layer_idx] # [bsz, 8, rank, 128] + + # indexing, [bsz, 8, sparse_budget, rank] + index_expanded = retrieval_position_ids.unsqueeze(-1).expand( + -1, -1, -1, u.size(-1) + ) # [bsz, 8, sparse_budget, rank] + u_expand = u.unsqueeze(1).expand( + -1, self.num_key_value_heads, -1, -1 + ) # [bsz, 8, 128k, rank] + U_head = torch.gather(u_expand, 2, index_expanded) + + # [bsz, 8, 
sparse_budget, rank] -matmul- [8, rank, 128] --> [bsz, 8, sparse_budget, 128] + result = torch.einsum('bhrk,bhkd->bhrd', U_head, sv) + + # # rope the key cache + cos, sin = cos_sin_cache + result = apply_rotary_pos_emb_single(result, cos, sin, retrieval_position_ids) + + # send to buffer + self.k_cache_buffer[layer_idx][:, :, self.sparse_start: self.sparse_end].copy_( + result, non_blocking=True + ) + gen_offset = ( + self.gen_offset + if layer_idx == self.num_layers - 1 + else self.gen_offset + self.incoming_q_len + ) + + return self.k_cache_buffer[layer_idx][:, :, : self.sparse_end + gen_offset] + + def update_kv_cache( + self, + new_k_cache: torch.Tensor, + new_v_cache: torch.Tensor, + layer_idx: int, + ): + + incoming = new_k_cache.shape[-2] + self.v_cache_buffer[layer_idx][ + :, + :, + self.sparse_end + + self.gen_offset: self.sparse_end + + self.gen_offset + + incoming, + ].copy_(new_v_cache, non_blocking=True) + self.k_cache_buffer[layer_idx][ + :, + :, + self.sparse_end + + self.gen_offset: self.sparse_end + + self.gen_offset + + incoming, + ].copy_(new_k_cache, non_blocking=True) + + if layer_idx == self.num_layers - 1: + self.kv_offset += incoming + self.gen_offset += incoming + + +@KV_REGISTRY.register('SinkKV') +class SinkKVCache(DynamicCache): + def __init__( + self, + num_hidden_layers, + window_length, + num_sink_tokens, + ): + super().__init__() + self.window_length = window_length + self.num_sink_tokens = num_sink_tokens + self.cos_sin_rerotation_cache = {} + self._cos_cache = None + self._sin_cache = None + + @staticmethod + def _rotate_half(x): + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + def _apply_key_rotary_pos_emb( + self, key_states, cos, sin + ): + rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin) + return rotated_key_states + + def _get_rerotation_cos_sin( + self, key_states, cos, sin + ): + if key_states.shape[-2] not in self.cos_sin_rerotation_cache: + # Upcast to float32 temporarily for better accuracy + cos = cos.to(torch.float32) + sin = sin.to(torch.float32) + + original_cos = cos[self.num_sink_tokens + key_states.shape[-2]:] + shifted_cos = cos[self.num_sink_tokens:-key_states.shape[-2]] + original_sin = sin[self.num_sink_tokens + key_states.shape[-2]:] + shifted_sin = sin[self.num_sink_tokens:-key_states.shape[-2]] + rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin + rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin + + self.cos_sin_rerotation_cache[key_states.shape[-2]] = ( + rerotation_cos.to(key_states.dtype).unsqueeze(0), + rerotation_sin.to(key_states.dtype).unsqueeze(0), + ) + return self.cos_sin_rerotation_cache[key_states.shape[-2]] + + def get_seq_length(self, layer_idx=0): + """Returns the sequence length of the cached states. + + A layer index can be optionally passed. 
+ """ + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_max_cache_shape(self): + """Returns the maximum sequence length of the cache object, in case of + SinkCache it is the window length.""" + return self.window_length + + def update( + self, + key_states, + value_states, + layer_idx, + cache_kwargs, + ): + + sin = cache_kwargs.get('sin') + cos = cache_kwargs.get('cos') + partial_rotation_size = cache_kwargs.get('partial_rotation_size') + using_rope = cos is not None and sin is not None + + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + if using_rope and layer_idx == 0: + + if cos.dim() == 2: + self._cos_cache = cos + self._sin_cache = sin + else: + if self._cos_cache is None: + self._cos_cache = cos[0, ...] + self._sin_cache = sin[0, ...] + elif self._cos_cache.shape[0] < self.window_length: + self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0) + self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0) + + # [bsz, num_heads, seq_len, head_dim] + if len(self.key_cache) <= layer_idx: + # Empty cache + self.key_cache.append(key_states) + self.value_cache.append(value_states) + + elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length: + # Growing cache + self.key_cache[layer_idx] = \ + torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = \ + torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + else: + # Shifting cache + keys_to_keep = self.key_cache[layer_idx][ + :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2]: + ] + + if using_rope: + rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin( + key_states, + self._cos_cache[: self.window_length], + self._sin_cache[: self.window_length] + ) + if partial_rotation_size is not None: + keys_to_keep, keys_pass = ( + keys_to_keep[..., :partial_rotation_size], + keys_to_keep[..., partial_rotation_size:], + ) + keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, + rerotation_cos, + rerotation_sin) + if partial_rotation_size is not None: + keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1) + + # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens + sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens] + + self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2) + + sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens] + values_to_keep = self.value_cache[layer_idx][ + :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2]: + ] + + self.value_cache[layer_idx] = torch.cat([sink_values, + values_to_keep, + value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def _reset_states(self): + self.key_cache = [] + self.value_cache = [] + self._seen_tokens = 0 diff --git a/llmc/compression/sparsification/shortgpt.py b/llmc/compression/sparsification/shortgpt.py index c8c8dc41..9684e19d 100644 --- a/llmc/compression/sparsification/shortgpt.py +++ b/llmc/compression/sparsification/shortgpt.py @@ -19,6 +19,7 @@ class ShortGPT(BaseBlockwiseSparsification): def __init__(self, model, sparsity_config, input, padding_mask, config): super().__init__(model, sparsity_config, input, padding_mask, config) + self.importances = np.zeros(len(self.blocks)) def block_opt(self, block): block = block.cuda() @@ -60,9 +61,7 @@ def subset_transform( output_feat ): # calculate BI score - if self.sparser.importances is None: 
- self.sparser.importances = np.zeros(len(self.blocks)) - self.sparser.importances[self.block_idx] = self.compute_bi( + self.importances[self.block_idx] = self.compute_bi( input_feat[0], output_feat[0] ).sum().cpu().item() @@ -71,10 +70,10 @@ def remove_layers( self, layers_to_remove: Optional[List[int]] = [] ): - if not layers_to_remove and self.sparser.n_prune_layers: + if not layers_to_remove and self.n_prune_layers: layers_to_remove = np.argsort( - np.array(self.sparser.importances) - )[:self.sparser.n_prune_layers].tolist() + np.array(self.importances) + )[:self.n_prune_layers].tolist() for idx in sorted(layers_to_remove, reverse=True): try: @@ -85,7 +84,7 @@ def remove_layers( @torch.no_grad() def deploy(self, deploy_format): - logger.info(f'After compute, BI scores are {self.sparser.importances}') + logger.info(f'After compute, BI scores are {self.importances}') logger.info('-- deploy_sparsity_model start --') logger.info(f'sparsity_config : {self.sparsity_config}') logger.info('-- begin remove layers --') diff --git a/llmc/compression/sparsification/sparse.py b/llmc/compression/sparsification/sparse.py deleted file mode 100644 index 09553572..00000000 --- a/llmc/compression/sparsification/sparse.py +++ /dev/null @@ -1,9 +0,0 @@ -class Sparser: - def __init__(self, sparsity_constraint, **kwargs): - if 'sparsity' in sparsity_constraint: - self.sparsity = sparsity_constraint['sparsity'] - self.W_mask = None - elif 'n_prune_layers' in sparsity_constraint: - self.n_prune_layers = sparsity_constraint['n_prune_layers'] - self.importances = None - self.kwargs = kwargs diff --git a/llmc/eval/eval_ppl.py b/llmc/eval/eval_ppl.py index fe1613bd..d598218c 100644 --- a/llmc/eval/eval_ppl.py +++ b/llmc/eval/eval_ppl.py @@ -32,7 +32,6 @@ def eval_func(self, model, testenc, seq_len, bs, eval_pos): lm_logits = model.model(inputs).logits model.reset_kv() - # Shift logits and labels for next token prediction shift_logits = lm_logits[:, :-1, :].contiguous() shift_labels = inputs[:, 1:] diff --git a/llmc/eval/utils.py b/llmc/eval/utils.py index 26af99fb..c6afae1a 100644 --- a/llmc/eval/utils.py +++ b/llmc/eval/utils.py @@ -79,7 +79,7 @@ def eval_model(model, blockwise_opt, eval_list, eval_pos): if do_eval: if eval_pos == 'transformed': blockwise_opt.deploy('origin_float') - elif eval_pos == 'fake_quant': + elif eval_pos in ['fake_quant', 'fake_quant_wo_kv']: blockwise_opt.deploy('fake_quant') for eval_class, config_for_eval in eval_list: if eval_pos in config_for_eval.eval.eval_pos: From 620cf410d740989580507f46341b5c9d2324119f Mon Sep 17 00:00:00 2001 From: gushiqiao Date: Mon, 13 Jan 2025 15:33:23 +0800 Subject: [PATCH 2/3] Support ShadowKV and fix bugs --- configs/quantization/backend/vllm/fp8/gptq_fp8.yml | 4 ++-- configs/quantization/methods/RTN/rtn_w_a_wint4aint8.yml | 2 -- llmc/compression/quantization/quant.py | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/configs/quantization/backend/vllm/fp8/gptq_fp8.yml b/configs/quantization/backend/vllm/fp8/gptq_fp8.yml index ec5db89f..905be88a 100644 --- a/configs/quantization/backend/vllm/fp8/gptq_fp8.yml +++ b/configs/quantization/backend/vllm/fp8/gptq_fp8.yml @@ -27,14 +27,14 @@ eval: quant: method: GPTQ weight: - quant_type: float_quant + quant_type: float-quant # Support ["e4m3", "e5m2"] bit: e4m3 symmetric: True granularity: per_channel use_qtorch: True act: - quant_type: float_quant + quant_type: float-quant # Support ["e4m3", "e5m2"] bit: e4m3 symmetric: True diff --git a/configs/quantization/methods/RTN/rtn_w_a_wint4aint8.yml 
b/configs/quantization/methods/RTN/rtn_w_a_wint4aint8.yml index 0e105a70..237c2118 100644 --- a/configs/quantization/methods/RTN/rtn_w_a_wint4aint8.yml +++ b/configs/quantization/methods/RTN/rtn_w_a_wint4aint8.yml @@ -17,7 +17,6 @@ eval: quant: method: RTN weight: - quant_type: int-quant bit: 48 bit4: symmetric: False @@ -32,7 +31,6 @@ quant: granularity: per_channel int_range: [-120, 120] act: - quant_type: int-quant bit: 8 symmetric: True granularity: per_token diff --git a/llmc/compression/quantization/quant.py b/llmc/compression/quantization/quant.py index 3d2192fe..17aa0981 100644 --- a/llmc/compression/quantization/quant.py +++ b/llmc/compression/quantization/quant.py @@ -1250,7 +1250,7 @@ def __init__(self, bit, bit4, bit8, **kwargs): self.bit_settings[bit]['zeros_qmin'] = self.bit_settings[bit]['qmin'] self.bit_settings[bit]['zeros_qmax'] = self.bit_settings[bit]['qmax'] - def reshape_tensor(self, tensor, bit): + def reshape_tensor(self, tensor, bit=4): granularity = self.bit_settings[bit].get('granularity') if granularity == 'per_group': group_size = self.bit_settings[bit].get('group_size') From 1ee45d0eb7c8e23b90bd8765d827f86a9d6f468d Mon Sep 17 00:00:00 2001 From: gushiqiao Date: Mon, 13 Jan 2025 15:40:09 +0800 Subject: [PATCH 3/3] Support ShadowKV and fix bugs --- configs/quantization/backend/vllm/fp8/awq_fp8_static.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/quantization/backend/vllm/fp8/awq_fp8_static.yml b/configs/quantization/backend/vllm/fp8/awq_fp8_static.yml index 1f09e77c..c4542507 100644 --- a/configs/quantization/backend/vllm/fp8/awq_fp8_static.yml +++ b/configs/quantization/backend/vllm/fp8/awq_fp8_static.yml @@ -34,8 +34,8 @@ quant: granularity: per_tensor use_qtorch: True act: - # Support ["e4m3", "e5m2"] quant_type: float-quant + # Support ["e4m3", "e5m2"] bit: e4m3 symmetric: True granularity: per_tensor
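
Illustrative sketch (not part of the diff): the dispatch performed by set_kv_sparse_config() in base_blockwise_sparsification.py. The 'kvcache' section of the sparsity config names a cache class registered in KV_REGISTRY ('ShadowKV' or 'SinkKV') and supplies its kwargs. The config values and the layer count below are placeholders chosen for illustration, not taken from any shipped config.

from llmc.utils.registry_factory import KV_REGISTRY

# Hypothetical 'kvcache' section of a sparsity config.
kvcache_cfg = {
    'method': 'SinkKV',
    'window_length': 1024,
    'num_sink_tokens': 4,
}

# Mirrors set_kv_sparse_config(): look up the registered cache class and build it.
kv_module = KV_REGISTRY[kvcache_cfg['method']](
    num_hidden_layers=32,  # the real code reads this from model_config; 32 is a placeholder
    window_length=kvcache_cfg['window_length'],
    num_sink_tokens=kvcache_cfg['num_sink_tokens'],
)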
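
Illustrative sketch (not part of the diff): the attention swap that replace_attention() performs when replace_attn is set. _LLMC_ATTN_MAP_ in sparsification/attn_utils.py maps the kvcache method and model type to a wrapper class, and ShadowKVAttention.new() wraps an existing attention module while reusing its projections. The tiny LlamaConfig is an assumption for illustration, and the example assumes a transformers version whose LlamaAttention exposes head_dim, num_key_value_groups and attention_dropout, which the wrapper reads.

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention

from llmc.compression.sparsification.attn_utils import _LLMC_ATTN_MAP_

# Small, made-up config so the example stays cheap to build.
cfg = LlamaConfig(
    hidden_size=512,
    num_attention_heads=8,
    num_key_value_heads=2,
    num_hidden_layers=2,
)
orig_attn = LlamaAttention(cfg, layer_idx=0)

# What replace_attention() resolves for a Llama model with method 'ShadowKV'.
attn_cls = _LLMC_ATTN_MAP_['ShadowKV']['Llama']
shadow_attn = attn_cls.new(orig_attn)  # reuses the original q/k/v/o projections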
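
Illustrative sketch (not part of the diff): how SinkKVCache behaves at run time through DynamicCache.update(). The cache keeps num_sink_tokens leading positions plus a sliding window, so its length stops growing at window_length. Shapes follow the [bsz, num_heads, seq_len, head_dim] convention used in kvsparse.py; cos/sin are omitted from cache_kwargs here, so the kept keys are not re-rotated, which is enough to show the eviction policy. All sizes are made up for the example.

import torch

from llmc.compression.sparsification.kvsparse import SinkKVCache

cache = SinkKVCache(num_hidden_layers=2, window_length=16, num_sink_tokens=4)

bsz, n_heads, head_dim = 1, 8, 64
for _ in range(10):                       # 10 decode steps of 4 tokens each
    k = torch.randn(bsz, n_heads, 4, head_dim)
    v = torch.randn(bsz, n_heads, 4, head_dim)
    for layer_idx in range(2):
        # Without 'cos'/'sin' in cache_kwargs the shifted keys are kept as-is
        # (no RoPE re-rotation); the sink + sliding-window eviction still applies.
        cache.update(k, v, layer_idx, cache_kwargs={})

print(cache.get_seq_length())             # capped at window_length (16), not 40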
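
Illustrative sketch (not part of the diff): the re-rotation cache built by SinkKVCache._get_rerotation_cos_sin is the angle-difference identities applied elementwise to the stored cos/sin tables. Moving a kept key from its original RoPE angle to its shifted angle only needs cos(t_orig - t_shift) and sin(t_shift - t_orig), so no rotary embedding has to be recomputed. A scalar check of that identity, with arbitrary angles:

import torch

t_orig = torch.tensor(12.345, dtype=torch.float64)   # RoPE angle at the key's original position
t_shift = torch.tensor(3.21, dtype=torch.float64)    # RoPE angle at the position it moves to

# Same arithmetic as _get_rerotation_cos_sin, on scalars instead of cached tables.
rerotation_cos = t_orig.cos() * t_shift.cos() + t_orig.sin() * t_shift.sin()
rerotation_sin = -t_orig.sin() * t_shift.cos() + t_orig.cos() * t_shift.sin()

assert torch.allclose(rerotation_cos, (t_orig - t_shift).cos())
assert torch.allclose(rerotation_sin, (t_shift - t_orig).sin())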