[Model] Support Mamba2 (Codestral Mamba) #9292

Open · wants to merge 17 commits into main
29 changes: 22 additions & 7 deletions tests/models/decoder_only/language/test_mamba.py
@@ -4,14 +4,22 @@
Run `pytest tests/models/test_mamba.py`.
"""
import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams

from ...utils import check_outputs_equal

MODELS = ["state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev"]
MODELS = [
"state-spaces/mamba-130m-hf",
"tiiuae/falcon-mamba-tiny-dev",
# TODO: Compare to a Mamba2 model. The HF transformers implementation of
# Mamba2 is buggy for Codestral as it doesn't handle n_groups.
# See https://github.com/huggingface/transformers/pull/35943
# "mistralai/Mamba-Codestral-7B-v0.1",
]


# Use lower-level interfaces to create this greedy generator, as mamba will
@@ -21,6 +29,10 @@ def generate_greedy(model_name, example_prompts, max_tokens):
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Set the device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Generate texts from the prompts
outputs = []
for prompt in example_prompts:
@@ -29,7 +41,9 @@ def generate_greedy(model_name, example_prompts, max_tokens):
input_ids = inputs["input_ids"].to(model.device)

# Generate text using the model's generate method directly
generated_ids = model.generate(input_ids, max_new_tokens=max_tokens)
generated_ids = model.generate(input_ids,
max_new_tokens=max_tokens,
do_sample=False)
generated_text = tokenizer.decode(generated_ids[0],
skip_special_tokens=True)

@@ -50,7 +64,8 @@ def test_models(
) -> None:
hf_outputs = generate_greedy(model, example_prompts, max_tokens)

with vllm_runner(model, dtype=dtype) as vllm_model:
# Set max_num_seqs to keep Codestral from going OOM at fp32
with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

# This test is for verifying whether the model's extra_repr
@@ -81,7 +96,7 @@ def test_batching(
) -> None:
# To pass the small model tests, we need full precision.
for_loop_outputs = []
with vllm_runner(model, dtype=dtype) as vllm_model:
with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
for prompt in example_prompts:
for_loop_outputs.append(
vllm_model.generate_greedy([prompt], max_tokens)[0])
@@ -165,7 +180,7 @@ def test_parallel_sampling(
max_tokens: int,
) -> None:

with vllm_runner(model, dtype=dtype) as vllm_model:
with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
for_loop_outputs = []
for _ in range(10):
for_loop_outputs.append(
@@ -232,7 +247,7 @@ def test_models_preemption_recompute(
# Tests that outputs are identical with and without preemptions (recompute)
assert dtype == "float"

with vllm_runner(model, dtype=dtype) as vllm_model:
with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
vllm_model.model.llm_engine.scheduler[
0].ENABLE_ARTIFICIAL_PREEMPT = True
preempt_vllm_outputs = vllm_model.generate_greedy(
@@ -283,7 +298,7 @@ def test_state_cleanup(
# This test is for verifying that the Mamba state is cleaned up between
# steps. If it is not cleaned up, an error would be expected.
try:
with vllm_runner(model, dtype=dtype) as vllm_model:
with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
for _ in range(10):
vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
except ValueError:
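Note: the max_num_seqs=16 cap threaded through the hunks above bounds how many sequences the vLLM scheduler runs concurrently, which is what keeps the fp32 Codestral Mamba2 run from going OOM. A minimal, hypothetical usage sketch of the same setting outside the test harness (the model name and token budget are placeholders, assuming the public vllm.LLM entry point):

# Illustrative sketch only; not part of this PR's diff.
from vllm import LLM, SamplingParams

# max_num_seqs bounds the scheduler's concurrent batch; smaller values
# trade throughput for a lower peak memory footprint.
llm = LLM(model="state-spaces/mamba-130m-hf",
          dtype="float32",
          max_num_seqs=16)

params = SamplingParams(temperature=0.0, max_tokens=32)  # greedy decoding
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)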
8 changes: 6 additions & 2 deletions vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
@@ -293,7 +293,8 @@ def _chunk_scan_fwd_kernel(
dA_cs_m_boundary = tl.load(
dA_cumsum_ptr +
(pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize,
mask=(pid_m * BLOCK_SIZE_M + c_off - 1) > -1,
mask=(((pid_m * BLOCK_SIZE_M + c_off - 1) > -1)
and ((pid_m * BLOCK_SIZE_M + c_off) < chunk_size)),
other=0.0).to(tl.float32)

if HAS_SEQ_IDX:
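The widened mask above guards the boundary load on both sides: the cumulative-sum entry one position before the block's first row is only read when that index lies inside [0, chunk_size), and otherwise the load falls back to other=0.0. A rough scalar Python sketch of the same guard, with hypothetical names, for readers less familiar with Triton masks:

# Illustrative sketch only; the real kernel expresses this via tl.load's mask.
def load_dA_cs_boundary(dA_cumsum, pid_m, block_size_m, c_off, chunk_size):
    idx = pid_m * block_size_m + c_off - 1
    # Lower bound: the row before the block must exist (idx > -1).
    # Upper bound: it must still lie inside the current chunk
    # (idx + 1 < chunk_size), which is the condition this PR adds.
    if idx > -1 and idx + 1 < chunk_size:
        return float(dA_cumsum[idx])
    return 0.0  # mirrors other=0.0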
@@ -463,7 +464,10 @@ def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
p += (s % chunk_size > 0)

# get the dimensions
_s, _e = s // chunk_size + p, e // chunk_size + p + 1
# - the + 1 for _e is to shift the boundary by one chunk
# - this shifting is not needed if chunk_size divides e
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
> 0)

# adjust indices and offsets
chunk_indices[_s:_e] -= p
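The change above only extends the upper chunk bound when the sequence end e spills into a partial chunk; the previous unconditional + 1 over-extended the range whenever chunk_size divides e. A small, self-contained Python illustration of the boundary arithmetic (a hypothetical helper, not the function in this file, which additionally loops over all sequences in seq_idx):

# Illustrative sketch only: map a token span [s, e) onto the half-open
# chunk range [_s, _e) after p boundary chunks have already been inserted.
def chunk_range(s, e, p, chunk_size=4):
    # A sequence that starts mid-chunk introduces one extra boundary chunk.
    p += (s % chunk_size > 0)
    _s = s // chunk_size + p
    # Only extend _e past the last full chunk when e is not a multiple of
    # chunk_size; this is the correction made in the diff above.
    _e = e // chunk_size + p + (e % chunk_size > 0)
    return _s, _e

print(chunk_range(s=2, e=8, p=0))  # (1, 3): e lands on a chunk boundary
print(chunk_range(s=2, e=9, p=0))  # (1, 4): e spills into a partial chunk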