Error in the file 2:4_w4a16_group-128_recipe.yaml #154

Open
carrot-o0o opened this issue Sep 10, 2024 · 0 comments
Labels: bug (Something isn't working)


carrot-o0o commented Sep 10, 2024

Describe the bug
This is a minor issue, but I think the quantization configuration in [examples/quantization_24_sparse_w4a16/2:4_w4a16_group-128_recipe.yaml](https://github.com/vllm-project/llm-compressor/blob/main/examples/quantization_24_sparse_w4a16/2%3A4_w4a16_group-128_recipe.yaml) should include `ignore: ["lm_head"]`, as shown below. Without it, saving the quantized model fails with a `ValueError` raised by compressed-tensors, because `lm_head` does not follow the 2:4 sparsity pattern that the Marlin 2:4 compressor validates.

```yaml
quantization_stage:
  run_type: oneshot
  quantization_modifiers:
    GPTQModifier:
      sequential_update: false
      ignore: ["lm_head"]
      config_groups:
        group_0:
          weights:
            num_bits: 4
            type: "int"
            symmetric: true
            strategy: "channel"
          targets: ["Linear"]
```
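The effect of `ignore: ["lm_head"]` can be illustrated with a small, hypothetical selection function: modules whose type matches `targets` are quantized unless their name appears in `ignore`. `select_targets` and the toy module map below are illustrative assumptions, not llm-compressor's API, and the sketch assumes exact name matching only.

```python
# Hypothetical illustration of how an `ignore` list narrows the modules a
# quantization modifier targets. NOT llm-compressor's implementation; it
# assumes exact module-name matching only.

def select_targets(modules, targets, ignore):
    """Return names of modules whose type is in `targets` and whose
    name is not listed in `ignore`."""
    return [
        name
        for name, module_type in modules.items()
        if module_type in targets and name not in ignore
    ]


# Toy stand-in for a Llama-style module tree: module name -> class name.
llama_like = {
    "model.layers.0.self_attn.q_proj": "Linear",
    "model.layers.0.mlp.down_proj": "Linear",
    "model.norm": "LlamaRMSNorm",
    "lm_head": "Linear",
}

# Without ignore: lm_head is selected along with the decoder Linear layers.
print(select_targets(llama_like, targets=["Linear"], ignore=[]))
# With ignore=["lm_head"]: only the decoder Linear layers remain.
print(select_targets(llama_like, targets=["Linear"], ignore=["lm_head"]))
```

Presumably, once `lm_head` is excluded it is not quantized, so the Marlin 2:4 compressor that appears in the traceback never has to validate it.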

Expected behavior
The example should run to completion and save the compressed model without errors.

Environment
  1. OS: 22.04
  2. Python version: 3.10
  3. LLM Compressor version or commit hash: 7a0d232
  4. ML framework version(s): torch 2.4.0
  5. Other Python package versions: compressed-tensors 0.5.0
  6. Other relevant environment information: CUDA 12.3

To Reproduce
I ran `python examples/quantization_24_sparse_w4a16/llama7b_sparse_w4a16.py` after changing the model path to a different Llama model and the recipe path to `2:4_w4a16_group-128_recipe.yaml`.

Errors

The error raised without `ignore: ["lm_head"]`:

```
Traceback (most recent call last):
  File "~/projects/sparse/llm-compressor/examples/quantization_24_sparse_w4a16/llama7b_sparse_w4a16.py", line 40, in <module>
    apply(
  File "~/projects/sparse/llm-compressor/src/llmcompressor/transformers/finetune/text_generation.py", line 93, in apply
    main(model_args, data_args, training_args)
  File "~/sparse/llm-compressor/src/llmcompressor/transformers/finetune/text_generation.py", line 348, in main
    stage_runner.run_sequential_stages(checkpoint)
  File "~/projects/sparse/llm-compressor/src/llmcompressor/transformers/finetune/runner.py", line 291, in run_sequential_stages
    self.one_shot(stage=stage_name)
  File "~/projects/sparse/llm-compressor/src/llmcompressor/transformers/finetune/runner.py", line 194, in one_shot
    save_model_and_recipe(
  File "~/projects/sparse/llm-compressor/src/llmcompressor/pytorch/model_load/helpers.py", line 110, in save_model_and_recipe
    model.save_pretrained(
  File "~/projects/sparse/llm-compressor/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py", line 123, in save_pretrained_wrapper
    compressed_state_dict = compressor.compress(model, state_dict)
  File "~/conda/llm-compresser/lib/python3.10/site-packages/compressed_tensors/compressors/model_compressor.py", line 241, in compress
    compressed_state_dict = self.quantization_compressor.compress(
  File "~/conda/llm-compresser/lib/python3.10/site-packages/compressed_tensors/compressors/marlin_24.py", line 149, in compress
    self.validate_sparsity_structure(prefix, value)
  File "~/conda/llm-compresser/lib/python3.10/site-packages/compressed_tensors/compressors/marlin_24.py", line 99, in validate_sparsity_structure
    if not tensor_follows_mask_structure(weight):
  File "~/conda/llm-compresser/lib/python3.10/site-packages/compressed_tensors/utils/helpers.py", line 91, in tensor_follows_mask_structure
    raise ValueError()
ValueError
```
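The traceback ends in `tensor_follows_mask_structure`, called from the Marlin 2:4 compressor's `validate_sparsity_structure`. As a rough illustration of that kind of validation, here is a simplified pure-Python sketch. It assumes "2:4" means at least two zeros in every contiguous group of four weights (row-major); `follows_mask_structure` is a hypothetical name and this is not the compressed-tensors source. A weight that was never pruned to 2:4, as described above for `lm_head`, fails it.

```python
# Simplified sketch of an n:m mask-structure check, in the spirit of
# compressed_tensors' tensor_follows_mask_structure. Hypothetical,
# pure Python, not the library's code.

def follows_mask_structure(weight, mask="2:4"):
    """Return True if every contiguous group of m values (row-major)
    contains at least n zeros; raise ValueError otherwise."""
    n, m = (int(x) for x in mask.split(":"))
    flat = [v for row in weight for v in row]
    if len(flat) % m != 0:
        raise ValueError(f"size {len(flat)} is not divisible by {m}")
    for start in range(0, len(flat), m):
        group = flat[start:start + m]
        zeros = sum(1 for v in group if v == 0)
        if zeros < n:
            raise ValueError(f"group at offset {start} has only {zeros} zeros")
    return True


pruned = [[0.0, 0.7, 0.0, -1.2], [0.3, 0.0, 0.0, 0.9]]  # two zeros per group of four
dense = [[0.1, 0.7, 0.2, -1.2], [0.3, 0.5, 0.0, 0.9]]   # not 2:4 sparse

print(follows_mask_structure(pruned))  # True
try:
    follows_mask_structure(dense)
except ValueError as err:
    print("ValueError:", err)  # the first group has no zeros
```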


carrot-o0o added the bug label on Sep 10, 2024.