
[Model] Support Mamba2 (Codestral Mamba) #9292

Open
wants to merge 17 commits into base: main

Conversation

tlrmchlsmth
Collaborator

@tlrmchlsmth tlrmchlsmth commented Oct 11, 2024

Add support for Mamba2. Not thoroughly tested yet, but Codestral Mamba has legible outputs.

Todo:

  • Integration tests
  • Support Chunked Prefill
  • Incorporate mamba_chunk_scan_combined kernel to avoid the dependency on mamba_ssm
  • Fix tensor parallelism
  • Try to refactor the code for Mamba2's mixer layer to look more like Mamba's

Closes #6479


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run the other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@tlrmchlsmth
Collaborator Author

Notes on current state:

  1. Now that [Kernel][Model] Improve continuous batching for Jamba and Mamba #9189 has landed, we need to update mamba_chunk_scan_combined to take cache indices so that this PR works with the updated MambaCacheManager. Until then, this PR is not compatible with current main.
  2. TP does seem to work in the present state; however, I see bad output when using CUDA graphs + custom_all_reduce.

@yury-tokpanov

@tlrmchlsmth thank you very much for your work! Kindly asking, do you have any updates since your last post?

@tlrmchlsmth
Collaborator Author

Hi @yury-tokpanov, sorry, I am focusing on other things right now -- namely support for tensor parallelism in V1. Currently this PR needs some fairly hard debugging. Hopefully I'll be able to resume working on this PR once I finish that work.

@yury-tokpanov

@tlrmchlsmth Thanks for the update! I work at Zyphra, and we are interested in incorporating our Zamba2 model into vLLM (#9382). I'm using your PR as a starting point, since we need mamba2 layers for that. If you're open to it, I'd be happy to help with finishing this PR.

@tlrmchlsmth
Collaborator Author

Hi @yury-tokpanov, yes I would be very open to that! If you need any pointers/advice/help please feel free to reach out on the vllm developer slack (https://communityinviter.com/apps/vllm-dev/join-vllm-developers-slack)

@fabianlim fabianlim mentioned this pull request Dec 5, 2024
@fabianlim
Contributor

@tlrmchlsmth @yury-tokpanov We also recently opened a PR to add a new model (Bamba) that also requires mamba v2 support. For continuous batching it works, but supporting chunked prefill can be quite challenging.

cc: @ani300, @raghukiran1224, @njhill

tlrmchlsmth and others added 2 commits January 16, 2025 17:49
Signed-off-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>

mergify bot commented Jan 16, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tlrmchlsmth.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 16, 2025
@tlrmchlsmth tlrmchlsmth force-pushed the tms/mamba2 branch 2 times, most recently from bc9b5cf to 17923ad on January 16, 2025 22:16
@mergify mergify bot removed the needs-rebase label Jan 16, 2025
tlrmchlsmth and others added 7 commits January 17, 2025 02:03
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
@tlrmchlsmth tlrmchlsmth marked this pull request as ready for review January 20, 2025 23:36
Signed-off-by: Tyler Michael Smith <[email protected]>
@tlrmchlsmth
Collaborator Author

Note that at this point, much of the implementation is now taken directly from #10909

@yury-tokpanov

@tlrmchlsmth @fabianlim thanks for all your work! I have our internal implementation of Zamba2 based on a previous version of this PR. I'm going to rebase it. Would you recommend using this branch or the one from the Bamba PR #10909?

assert not is_lora_enabled

self.config = config
self.padding_idx = config.pad_token_id


is this one used anywhere?

Collaborator Author


I'll look into cleaning this up. I see several models that have similar, seemingly unused padding_idx variables.


mergify bot commented Jan 24, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tlrmchlsmth.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 24, 2025
# For eager just take the scheduler_config if avail
self.max_batch_size = self.scheduler_config.max_num_seqs
else:
self.max_batch_size = 128 + 2


what is the reason for setting max_batch_size at a much lower value than in mamba or jamba?

Collaborator Author


This was to avoid out-of-memory issues that I was seeing with CodestralMamba

Before landing I will check if we can simplify this code - I'm not sure when self.scheduler_config would be None

Contributor


I didn't do a very careful comparison, but given that Mamba2's design is to have a higher headdim, the cache size of Mamba2 should be much higher than that of Mamba1. Hence we should not expect to allocate caches for the same number of seqs as in Mamba1, I believe.
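To make the point concrete, here is a rough back-of-the-envelope sketch; the dimensions below are generic, illustrative Mamba/Mamba2-style values, not the config of Codestral Mamba or Bamba:

# Illustrative per-layer, per-sequence SSM state sizes (element counts only).
# All dimensions here are assumptions for illustration.
d_model, expand = 4096, 2
d_inner = expand * d_model             # 8192 channels

# Mamba1: recurrent state of shape (d_inner, d_state) with a small d_state.
mamba1_state = d_inner * 16            # 131,072 elements

# Mamba2: recurrent state of shape (nheads, headdim, d_state), where
# nheads * headdim == d_inner and d_state is much larger (e.g. 128).
mamba2_state = d_inner * 128           # 1,048,576 elements

print(mamba2_state / mamba1_state)     # -> 8.0, i.e. roughly 8x more cache per sequence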

@tlrmchlsmth
Collaborator Author

@yury-tokpanov I'd recommend using the mixer2 implementation from this PR -- I think there are only one or two small changes from the Bamba PR, but notably I did fix one correctness issue with a contiguous call that isn't in the Bamba PR.

Signed-off-by: Tyler Michael Smith <[email protected]>
@mergify mergify bot removed the needs-rebase label Jan 27, 2025
@DarkLight1337
Member

Heads up that Mamba tests currently fail on main: #12465

It would be great if you could solve the issue in this PR as well!

@tlrmchlsmth
Collaborator Author

tlrmchlsmth commented Jan 27, 2025

Heads up that Mamba tests currently fail on main: #12465

It would be great if you could solve the issue in this PR as well!

I'm looking into it. I see the following error in the logs, but so far am unable to reproduce this issue on an H100

[2025-01-14T06:40:30Z] FAILED models/decoder_only/language/test_mamba.py::test_mamba_cache_cg_padding[20-bfloat16-tiiuae/falcon-mamba-tiny-dev] - AssertionError: Error in memory profiling. Initial free memory 22215917568, current free memory 22215917568. This happens when the GPU memory was not properly cleaned up before initializing the vLLM instance.

Edit: fixed error message

@tlrmchlsmth
Collaborator Author

The other update I have is that I am trying to reproduce the humaneval results reported here, using https://github.com/neuralmagic/evalplus, but no luck so far:

humaneval (base tests)
pass@1: 0.220
pass@10:        0.406
humaneval+ (base + extra tests)
pass@1: 0.190
pass@10:        0.360

repro steps:

python codegen/generate.py --model mistralai/Mamba-Codestral-7B-v0.1 --bs 16 --temperature 0.2 --n_samples 50 --root "./results" --dataset humaneval --backend hf --dtype auto --tp 1
 
python evalplus/sanitize.py results2/humaneval/mistralai--Mamba-Codestral-7B-v0.1_vllm_temp_0.2/

evalplus.evaluate --dataset humaneval --samples results/humaneval/mistralai--Mamba-Codestral-7B-v0.1_vllm_temp_0.2-sanitized

@tlrmchlsmth
Collaborator Author

GSM8k results:

vllm:
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.2244|±  |0.0115|
|     |       |strict-match    |     5|exact_match|↑  |0.2995|±  |0.0126|


hf:
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.2305|±  |0.0116|
|     |       |strict-match    |     5|exact_match|↑  |0.2949|±  |0.0126|
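(For reference, a hypothetical lm-evaluation-harness invocation that would produce a table like this; the exact command used for these runs isn't recorded in the thread, so the model args, fewshot count, and batch size below are assumptions.)

# Hypothetical sketch using lm-evaluation-harness' Python API.
import lm_eval

results = lm_eval.simple_evaluate(
    model="vllm",  # or "hf" for the transformers baseline
    model_args="pretrained=mistralai/Mamba-Codestral-7B-v0.1,dtype=auto",
    tasks=["gsm8k"],
    num_fewshot=5,
    batch_size=16,
)
print(results["results"]["gsm8k"])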

@yury-tokpanov

yury-tokpanov commented Jan 27, 2025

I am unable to reproduce eval results for our Zamba2 model with lm_eval both for some loglikelihood tasks (winogrande, arc tasks) and generation tasks (like gsm8k), while some loglikelihood tasks are fine for some reason (mmlu, hellaswag).

When I dig deeper and compare the outputs layer by layer with our HF implementation, I see there is a small discrepancy in mamba2 layers starting from the first one, and it accumulates over the whole network. Final logits are within 1% of each other between the two implementations.

Going to check what's going on.

@fabianlim were you able to reproduce bamba results in lm_eval with your vllm implementation?

@fabianlim
Contributor

@yury-tokpanov No, I haven't yet tried reproducing the benchmarks on vLLM. I'll have to try it myself.

@fabianlim
Contributor

fabianlim commented Jan 28, 2025

@yury-tokpanov with @tlrmchlsmth's help, we have also verified the gsm8k number for Bamba against the published benchmark.

HF: 0.3662
VLLM: 0.3700

@yury-tokpanov

yury-tokpanov commented Jan 28, 2025

The computation of the gated RMS norm depends on the number of Mamba2 groups (see https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/ops/triton/layernorm_gated.py#L32). Our 7B model has 2 groups, so this definitely affects it. I'm still chasing other discrepancies.

Seems like Codestral 7B uses 8 groups, so it'll definitely be an issue for that model as well: https://huggingface.co/mistralai/Mamba-Codestral-7B-v0.1/blob/main/config.json
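For context, here is a minimal PyTorch sketch of what the grouped gated RMS norm reference computes (assuming norm_before_gate=False, as in the Mamba2 mixer); it paraphrases the linked layernorm_gated.py and is not code from this PR:

import torch
import torch.nn.functional as F

def gated_rms_norm_ref(x, z, weight, group_size, eps=1e-5):
    # Gate with SiLU(z) first (norm_before_gate=False), then RMS-normalize
    # using statistics computed per group of size group_size.
    x = x * F.silu(z)
    *lead, hidden = x.shape
    xg = x.view(*lead, hidden // group_size, group_size)
    rstd = torch.rsqrt(xg.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (xg * rstd).view(*lead, hidden) * weight

# With n_groups > 1 (e.g. 2 for Zamba2-7B, 8 for Codestral Mamba),
# group_size = hidden // n_groups, so the result differs from a single
# full-width RMSNorm over the whole hidden dimension.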


# Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated
@CustomOp.register("mixer2_gated_rms_norm")
class Mixer2RMSNormGated(CustomOp):


Collaborator Author


Thanks for the pointer! I'll take a look tomorrow.

@yury-tokpanov

yury-tokpanov commented Jan 29, 2025

After fixing the gated RMS norm, I was able to match the gsm8k results for our 7B model. I still see some task numbers being lower for some reason, so I'm going to investigate further. @fabianlim did you compare other evals? The ones I see differing that are also in your evals table are ARC-C and Winogrande.

Btw, we just merged the Zamba2 implementation into transformers, so once I'm done with the vLLM implementation, I'll create a PR here. I still need to fix/check correctness and evals, add PP support, and then clean things up. Our architecture uses shared transformer layers with LoRAs in them, so I'll need to think a little bit about how to adapt it to the vLLM style; it seems like there is a big refactoring going on, which already caught me with the KV cache.

@fabianlim
Contributor

fabianlim commented Jan 29, 2025

@yury-tokpanov I can reproduce the arc-challenge results on bamba

HF

2025-01-29:12:20:32,138 INFO     [evaluation_tracker.py:206] Saving results aggregated
2025-01-29:12:20:32,166 INFO     [evaluation_tracker.py:287] Saving per-sample results for: arc_challenge
hf (pretrained=ibm-fms/Bamba-9B,dtype=float16,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 25, batch_size: 16

|    Tasks    |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
|-------------|------:|------|-----:|--------|---|-----:|---|-----:|
|arc_challenge|      1|none  |    25|acc     |↑  |0.5896|±  |0.0144|
|             |       |none  |    25|acc_norm|↑  |0.6323|±  |0.0141|

VLLM

2025-01-29:12:39:23,549 INFO     [evaluation_tracker.py:206] Saving results aggregated
2025-01-29:12:39:23,563 INFO     [evaluation_tracker.py:287] Saving per-sample results for: arc_challenge
vllm (pretrained=ibm-fms/Bamba-9B,dtype=float16,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 25, batch_size: 16

|    Tasks    |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
|-------------|------:|------|-----:|--------|---|-----:|---|-----:|
|arc_challenge|      1|none  |    25|acc     |↑  |0.5836|±  |0.0144|
|             |       |none  |    25|acc_norm|↑  |0.6263|±  |0.0141|

Also for winogrande:
HF: 76.87 (taken from the published results)
VLLM: 76.63

@tlrmchlsmth
Collaborator Author

@yury-tokpanov could you share what you did to fix gated rms norm? I don't see n_groups being handled in zamba here https://github.com/huggingface/transformers/blob/main/src/transformers/models/zamba2/modeling_zamba2.py#L64-L79

@yury-tokpanov

yury-tokpanov commented Jan 29, 2025

@yury-tokpanov could you share what you did to fix gated rms norm? I don't see n_groups being handled in zamba here https://github.com/huggingface/transformers/blob/main/src/transformers/models/zamba2/modeling_zamba2.py#L64-L79

We have a new PR in transformers fixing this issue: huggingface/transformers#35943

I did the same thing for my vLLM implementation, but I ignored TP for now as I've been testing our models with TP=1 so far.

Also, I tested with the original Triton implementation from the mamba2 repo to make sure I'm getting the same eval results:

  • copy-pasted layernorm_gated.py
  • from vllm.model_executor.layers.mamba.ops.layernorm_gated import RMSNorm as RMSNormGated
  • replace self.norm = Mixer2RMSNormGated(intermediate_size // self.tp_size, eps=rms_norm_eps) with self.norm = RMSNormGated(intermediate_size, eps=rms_norm_eps, norm_before_gate=False, group_size=intermediate_size // n_groups)

Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
@tlrmchlsmth
Collaborator Author

tlrmchlsmth commented Jan 30, 2025

Updated to handle groups in Mixer2RMSNormGated -- GSM8K results are much improved:

before:
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.2244|±  |0.0115|
|     |       |strict-match    |     5|exact_match|↑  |0.2995|±  |0.0126|

after:
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.47|±  |0.0502|
|     |       |strict-match    |     5|exact_match|↑  | 0.47|±  |0.0502|

Will re-run humaneval as well

@tlrmchlsmth
Collaborator Author

humaneval results looking much better as well:

humaneval (base tests)
pass@1: 0.643
pass@10:        0.825
humaneval+ (base + extra tests)
pass@1: 0.552
pass@10:        0.748

input_size=conv_kernel_size,
output_size=self.conv_dim,
bias=use_conv_bias,
quant_config=None,


I'm curious, how did you decide which layers should be quantizable and which not? Did you run experiments?

Contributor


There were a few reasons, including the following: 1) llm-compressor does not quantize conv1d layers, and 2) the conv1d kernels do not support fp8. In our blog we have a few numbers when using this scheme. cc: @nwang-ibm

@yury-tokpanov

yury-tokpanov commented Feb 1, 2025

I rebased using the latest version of this PR, and now I'm getting this error from torch.ops._vllm_fa2_C.varlen_fwd() in vllm/vllm_flash_attn/flash_attn_interface.py:173, even though our headdim=224, which is divisible by 32:

This flash attention build does not support headdim not being a multiple of 32.

I see there was this FA3 revert commit for ViT MHA, reporting the same error: #12445

A bit weird, since I'm not using FA3, but something is broken nonetheless.

Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Successfully merging this pull request may close these issues:

[New Model]: Codestral Mamba