Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CORE] [QUANT] Support for GPTQModel's dynamic quantization per module override/control #7086

Open
wants to merge 58 commits into
base: main
Choose a base branch
from

Conversation

Qubitium
Copy link
Contributor

@Qubitium Qubitium commented Aug 2, 2024

GPTQModel v0.9.10-dev0 main branch has merged dynamic, per layer/module support of different gptq bits, sym, desc_act using a regex style definition. This is a work in process and we are awaiting feedback before release. We are targeting both vllm and sglang compat with the quant so would like to work with vllm to see if what is the best way forward.

Previously a gptq model has a single config that applies to all layers and all modules within nested layers. This change allows pin-point targeting of different gptq quantization config for specific layers and/or specific modules within specific layers for better optimization.

Sample model: https://huggingface.co/ModelCloud/TinyLlama-1.1B-Chat-v1.0-dynamic-GPTQ-2024-8-3

full quant config for sample:

{
  "bits": 4,
  "dynamic": {
    ".*\\.(?:1[0-5])\\..*": {
      "bits": 8
    },
    ".*\\.(?:1[6-9]|20|21)\\..*": {
      "bits": 8,
      "group_size": 64
    }
  },
  "group_size": 128,
  "desc_act": true,
  "static_groups": false,
  "sym": true,
  "lm_head": false,
  "damp_percent": 0.005,
  "damp_auto_increment": 0.0015,
  "true_sequential": true,
  "model_name_or_path": "./test_dynamic_model",
  "model_file_base_name": "model",
  "quant_method": "gptq",
  "checkpoint_format": "gptq",
  "meta": {
    "quantizer": "gptqmodel:0.9.10-dev0"
  }
}

Dynamic config explained:

# sample tinyllama 1.1B model has 22 layers
# default is 4bit, group_size 128
# layer index start at 0

# last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
# last 1/4 of the layers 16-21 has 8bit and group_size 64
dynamic = {
  # `.*\.` matches the layers_node prefix
  r".*\.(?:1[0-5])\..*": {"bits": 8,}, # match layer 10-15
  r".*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,}, # match layer 16-21
}

Same code to quantize using dynamic control: https://github.com/ModelCloud/GPTQModel/blob/main/tests/test_dynamic.py

Design choices:

  1. Need a def table to notify quantizer (GPTQModel) and infer engine (vllm) which layers has dynamic (override) quant config.
  2. Possible to generate a static all inclusive per layer/module def/table in json but content would not be human friendly as each nested layer with each nested module would need an entry. If a model has 44 layers and each layer has 6-8 modules, we are looking a t a 44x8 lines of json minimum.
  3. GPTQModel decided on a design where a simple regex: str key mapped to dict[str, int or bool] for both quantization and model inference/loading. Multiple regex/dynamic pairs can be defined and for matching, the rules are looped and first one that match, is applied.
  4. Upload loading, and looping over each layer/module, we check for dynamic (override) match and if matches, override the static quant config files for that layer/module.

Compat Notes:

dynamic config require that the model inference does not remerge the layers with different dyanmic/quant param values. MergedColumnParallel in Llama model in vllm for example merges mlp.gate and mlp.up. Dynamic override works but in this case, because they are fused/merged, these two layers must have exact same quant config values. Can't have one with 4bit and the other with 8bits.

TODO:

  1. unit test
  2. finalize design of loading so GPTQModel and vllm can agree, how to best pass/share dynamic layer/module quant override config via quantize_config json

Copy link

github-actions bot commented Aug 2, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mgoin
Copy link
Member

mgoin commented Aug 5, 2024

Hi @Qubitium thanks for sharing your interesting work!

We have the notion of variable quantization already in vLLM through our compressed-tensors integration. With this we can blend integer and float quantization of weights and/or activations within a single model config in a similar explicit target or regex manner. I recommend digging into the non-uniform support we already have for compressed-tensors and fp8 methods.
It would be interesting to see if your library could export into compressed-tensors format so it would work out-of-the-box in vLLM and Transformers!

Regarding merged layers, I think the performance and complexity cost of needing to support possibly unmerging layers like QKV or GateUp is too high. I want to recommend keeping the quantization level of merged layers the same so we (and several other inference engines) don't run into this issue.

If you are still open to editing your format, I also think dynamic isn't a clear term here since there is already the notion of static or dynamic quantization, which means something else. Also, the quantization isn't changing in any dynamic way. I would recommend using a name like non-uniform quantization, since we are not performing uniform quantization anymore but have settled on a non-uniform scheme.

@Qubitium
Copy link
Contributor Author

Qubitium commented Aug 5, 2024

@mgoin Wow, I totally missed this PR. After cursory check of the https://github.com/vllm-project/vllm/pull/6515/files PR, our pr is entirely redundant. The core concept is similar including re matching. The only little advantage of this pr, and very little at this point, is minimal code-change to bootstrap gptq flexible layer/module quant.

I will need to digest the vllm pr/unit tests to test with gptqmodel export. If gptq model can integrate with compressed_config protocol, then there is zero reason for this pr.

Regarding merged layers, I think the performance and complexity cost of needing to support possibly unmerging layers like >QKV or GateUp is too high. I want to recommend keeping the quantization level of merged layers the same so we (and >several other inference engines) don't run into this issue.

Yes, this our finding as well. Merged layers should retain the same scheme.

If you are still open to editing your format, I also think dynamic isn't a clear term here since there is already the notion of >static or dynamic quantization, which means something else. Also, the quantization isn't changing in any dynamic way. I >would recommend using a name like non-uniform quantization, since we are not performing uniform quantization ?>anymore but have settled on a non-uniform scheme.

I want the config to to be compatible to vllm/sglang, and since sglang for the most part re-uses/import vllm model weight/model layers. Do not want another protocol parser so if vllm compressed_config protocol works like I think it does, then this is good base moving forward for gptqmodel as well.

Copy link

mergify bot commented Dec 24, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Qubitium.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 24, 2024
@mergify mergify bot removed the needs-rebase label Dec 24, 2024
@Qubitium
Copy link
Contributor Author

Qubitium commented Feb 7, 2025

@mgoin We are working on the lm_head portion to get it ready for today.

gptqmodel has been merged into hf transformers, optimum, peft, and hf has agreed in principle with us to start deprecating autogptq in the near future.

We also have pending PR, slated for next week, that we are actively working with Nvidia staff which will introduce another gptq config for lora layer optimized for gptq.

What I am trying to say is, gptq quantized model, that is quantized via hf optimum or gptqmodel will use gptqmodel standard config and the config is expanding as features expand.

Can we be free to sync gptqconfig in vllm to gptqmodel format and not have deviations such as dynamic vs dynamic_cfg? For this PR, the dynamic_cfg field is the issue.

@Qubitium
Copy link
Contributor Author

Qubitium commented Feb 7, 2025

@mgoin We are working on the lm_head portion to get it ready for today.

gptqmodel has been merged into hf transformers, optimum, peft, and hf has agreed in principle with us to start deprecating autogptq in the near future.

We also have pending PR, slated for next week, thar we are actively working with Nvidia staff which will introduce another gptq config for lora layer optimized for gptq.

What I am trying to say is, gptq quantized model, that is quantized via hf optimum or gptqmodel will use gptqmodel standard config and the config is expanding as features expand. Hf and gptqmodel are not the only tools to generate gptq but we are the only actively maintained project that is exclusive gptq.

Can we be free to sync gptqconfig in vllm to gptqmodel format and not have deviations such as dynamic vs dynamic_cfg? After going over this PR, this var name sticks out to me.

For ctx, dynamic field was renamed to dynamic_cfg in earlier phase of the pr due to review feedback as dynamic doesnt actually mean runtime dynamism (if there is such a word) and in the ctx of vllm inference is nothing dynamic about it but more of a static override of configd per module.

@Qubitium
Copy link
Contributor Author

Qubitium commented Feb 7, 2025

@mgoin PR is ready for re-review but we are incapable of fixing the missing sign-off requirements. Pwned. We are testing various methods but the code itself is fine and good for review.

@Qubitium Qubitium requested a review from mgoin February 7, 2025 08:58
@jeejeelee
Copy link
Collaborator

DCO has passed

@Qubitium
Copy link
Contributor Author

Qubitium commented Feb 7, 2025

DCO has passed

How did you get it to pass?

@jeejeelee
Copy link
Collaborator

@Qubitium commiter can fix DCO failure directly

Copy link
Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah don't worry about the DCO, we can signoff right before merging and it doesn't block anything

assert isinstance(lm_head_layer.linear_method,
assert isinstance(lm_head_layer.quant_method,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be best to keep linear_method for both lm_head and embedding layer, since calling it quant_method doesn't make sense for the base case of unquantized methods. While I agree linear isn't perfect for embeddings, there isn't a strong reason to change it in this PR.

Copy link
Contributor Author

@Qubitium Qubitium Feb 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mgoin Sorry we didn't make it clear in our notes but there is reason for this change now that you mentioned it.

Please check https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/model_loader/loader.py#L398

We believe this is a bug fix of existing lm_head quantized init code. If lm_head is quantized using sym=False, as the ci-test was written by (me and @robertgshaw2-redhat) but using a model quantized (sym=False) by me, it can't route to Marlin kernel since Marlin doesn't support sym=False. Ci-test passes because it gets routed to fall-back cuda kernel. Without this fix, lm_head quantized that is compatible with Marlin kernel code will crash since it checks for quant_method attribute for correct Marlin init. So we synced lm_head attr to be same name as other modules to fix following crash:

Mode: https://huggingface.co/ModelCloud/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bits-dynamic-cfg-with-lm_head

tests/conftest.py:682: in __init__
    self.model = LLM(
vllm/utils.py:1051: in inner
    return fn(*args, **kwargs)
vllm/entrypoints/llm.py:242: in __init__
    self.llm_engine = self.engine_class.from_engine_args(
vllm/engine/llm_engine.py:484: in from_engine_args
    engine = cls(
vllm/engine/llm_engine.py:276: in __init__
    self._initialize_kv_caches()
vllm/engine/llm_engine.py:416: in _initialize_kv_caches
    self.model_executor.determine_num_available_blocks())
vllm/executor/executor_base.py:101: in determine_num_available_blocks
    results = self.collective_rpc("determine_num_available_blocks")
vllm/executor/uniproc_executor.py:51: in collective_rpc
    answer = run_method(self.driver_worker, method, args, kwargs)
vllm/utils.py:2220: in run_method
    return func(*args, **kwargs)
/root/miniconda3/envs/gp/lib/python3.11/site-packages/torch/utils/_contextlib.py:116: in decorate_context
    return func(*args, **kwargs)
vllm/worker/worker.py:229: in determine_num_available_blocks
    self.model_runner.profile_run()
/root/miniconda3/envs/gp/lib/python3.11/site-packages/torch/utils/_contextlib.py:116: in decorate_context
    return func(*args, **kwargs)
vllm/worker/model_runner.py:1235: in profile_run
    self._dummy_run(max_num_batched_tokens, max_num_seqs)
vllm/worker/model_runner.py:1346: in _dummy_run
    self.execute_model(model_input, kv_caches, intermediate_tensors)
/root/miniconda3/envs/gp/lib/python3.11/site-packages/torch/utils/_contextlib.py:116: in decorate_context
    return func(*args, **kwargs)
vllm/worker/model_runner.py:1765: in execute_model
    logits = self.model.compute_logits(hidden_or_intermediate_states,
vllm/model_executor/models/qwen2.py:496: in compute_logits
    logits = self.logits_processor(self.lm_head, hidden_states,
/root/miniconda3/envs/gp/lib/python3.11/site-packages/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/root/miniconda3/envs/gp/lib/python3.11/site-packages/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
vllm/model_executor/layers/logits_processor.py:74: in forward
    logits = self._get_logits(hidden_states, lm_head, embedding_bias)
vllm/model_executor/layers/logits_processor.py:111: in _get_logits
    logits = lm_head.quant_method.apply(lm_head,
vllm/model_executor/layers/quantization/gptq_marlin.py:406: in apply
    return self.kernel.apply_weights(layer, x, bias)
vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py:129: in apply_weights
    g_idx_sort_indices=layer.g_idx_sort_indices,
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = ParallelLMHead(num_embeddings=151936, embedding_dim=2048, org_vocab_size=151936, num_embeddings_padded=151936, tp_size=1), name = 'g_idx_sort_indices'

    def __getattr__(self, name: str) -> Any:
        if "_parameters" in self.__dict__:
            _parameters = self.__dict__["_parameters"]
            if name in _parameters:
                return _parameters[name]
        if "_buffers" in self.__dict__:
            _buffers = self.__dict__["_buffers"]
            if name in _buffers:
                return _buffers[name]
        if "_modules" in self.__dict__:
            modules = self.__dict__["_modules"]
            if name in modules:
                return modules[name]
>       raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'"
        )
E       AttributeError: 'ParallelLMHead' object has no attribute 'g_idx_sort_indices'

Copy link
Contributor Author

@Qubitium Qubitium Feb 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So if we revert to linear_method, we can bypass the bug by calling the process_weights_after_loading manually. But using quant_method attribute would allow existing code to do this without extra intervensions. It seems like a cleaner way since lm_head is treated like other modules vs doing it's own thing.

Comment on lines 167 to 254
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config
def __init__(self, quant_config: GPTQMarlinConfig, prefix: str) -> None:
self.quant_config = deepcopy(quant_config)
self.prefix = prefix

if len(self.quant_config.dynamic_cfg) > 0 and self.prefix:
# gptqmodel per module/layer dynamic_cfg my override/change base
# model quant config
self.quant_config.override_config(prefix=self.prefix)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems you have enough information do perform this same quant_config copy and override in get_quant_method, so why not keep the dynamism within that function where you already have the prefix?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mgoin Yes! We will push this override outside of __init__.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mgoin Fixed.

ZX-ModelCloud and others added 3 commits February 7, 2025 16:09
Signed-off-by: ZX-ModelCloud <[email protected]>
Signed-off-by: ZX-ModelCloud <[email protected]>
@Qubitium
Copy link
Contributor Author

Qubitium commented Feb 7, 2025

@mgoin Ready for re-review. There were some small clarity fixes in terms of logic expression, var name, comments committed since your last review in addition to the requested change about moving config override to outside of marlin __init__.

@Qubitium Qubitium changed the title [CORE] [QUANT] Support for GPTQModel's dynamic config control per module/layer [CORE] [QUANT] Support for GPTQModel's dynamic quantization per module override/control Feb 7, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants