diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py
index 854f4fe4f9195..147971be77600 100644
--- a/tests/models/decoder_only/language/test_mamba.py
+++ b/tests/models/decoder_only/language/test_mamba.py
@@ -4,6 +4,7 @@
 Run `pytest tests/models/test_mamba.py`.
 """
 import pytest
+import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from vllm.engine.arg_utils import EngineArgs
@@ -11,7 +12,14 @@
 from ...utils import check_outputs_equal
 
-MODELS = ["state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev"]
+MODELS = [
+    "state-spaces/mamba-130m-hf",
+    "tiiuae/falcon-mamba-tiny-dev",
+    # TODO: Compare to a Mamba2 model. The HF transformers implementation of
+    # Mamba2 is buggy for Codestral as it doesn't handle n_groups.
+    # See https://github.com/huggingface/transformers/pull/35943
+    # "mistralai/Mamba-Codestral-7B-v0.1",
+]
 
 
 # Use lower-level interfaces to create this greedy generator, as mamba will
@@ -21,6 +29,10 @@ def generate_greedy(model_name, example_prompts, max_tokens):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForCausalLM.from_pretrained(model_name)
 
+    # Set the device (GPU if available, else CPU)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
     # Generate texts from the prompts
     outputs = []
     for prompt in example_prompts:
@@ -29,7 +41,9 @@ def generate_greedy(model_name, example_prompts, max_tokens):
         input_ids = inputs["input_ids"].to(model.device)
 
         # Generate text using the model's generate method directly
-        generated_ids = model.generate(input_ids, max_new_tokens=max_tokens)
+        generated_ids = model.generate(input_ids,
+                                       max_new_tokens=max_tokens,
+                                       do_sample=False)
         generated_text = tokenizer.decode(generated_ids[0],
                                           skip_special_tokens=True)
 
@@ -50,7 +64,8 @@ def test_models(
 ) -> None:
     hf_outputs = generate_greedy(model, example_prompts, max_tokens)
 
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    # Set max_num_seqs to keep Codestral from going OOM at fp32
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
     # This test is for verifying whether the model's extra_repr
@@ -81,7 +96,7 @@ def test_batching(
 ) -> None:
     # To pass the small model tests, we need full precision.
     for_loop_outputs = []
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         for prompt in example_prompts:
             for_loop_outputs.append(
                 vllm_model.generate_greedy([prompt], max_tokens)[0])
@@ -165,7 +180,7 @@ def test_parallel_sampling(
     max_tokens: int,
 ) -> None:
 
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         for_loop_outputs = []
         for _ in range(10):
             for_loop_outputs.append(
@@ -232,7 +247,7 @@ def test_models_preemption_recompute(
     # Tests that outputs are identical with and w/o preemptions (recompute)
     assert dtype == "float"
 
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         vllm_model.model.llm_engine.scheduler[
             0].ENABLE_ARTIFICIAL_PREEMPT = True
         preempt_vllm_outputs = vllm_model.generate_greedy(
@@ -283,7 +298,7 @@ def test_state_cleanup(
     # This test is for verifying that the Mamba state is cleaned up between
    # steps. If it's not cleaned, an error would be expected.
     try:
-        with vllm_runner(model, dtype=dtype) as vllm_model:
+        with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
             for _ in range(10):
                 vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
     except ValueError:
diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
index 722fbd714ca8f..7ef5111227eb4 100644
--- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
+++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
@@ -293,7 +293,8 @@ def _chunk_scan_fwd_kernel(
         dA_cs_m_boundary = tl.load(
             dA_cumsum_ptr +
             (pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize,
-            mask=(pid_m * BLOCK_SIZE_M + c_off - 1) > -1,
+            mask=(((pid_m * BLOCK_SIZE_M + c_off - 1) > -1)
+                  and ((pid_m * BLOCK_SIZE_M + c_off) < chunk_size)),
             other=0.0).to(tl.float32)
 
     if HAS_SEQ_IDX:
@@ -463,7 +464,10 @@ def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
         p += (s % chunk_size > 0)
 
         # get the dimensions
-        _s, _e = s // chunk_size + p, e // chunk_size + p + 1
+        # - the + 1 for _e is to shift the boundary by one chunk
+        # - this shifting is not needed if chunk_size divides e
+        _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
+                                                             > 0)
 
         # adjust indices and offsets
         chunk_indices[_s:_e] -= p
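As a quick sanity check on the new boundary term in `_seq_idx_to_chunk_indices_offsets` (a toy sketch with assumed values, not code from this patch): the old unconditional `+ 1` shifted the upper slice bound by a full chunk even when `chunk_size` divides `e`, while `(e % chunk_size > 0)` only shifts when the sequence end falls mid-chunk.

    # Toy values, chosen only for illustration.
    chunk_size, p = 8, 0
    for e in (16, 13):  # end on a chunk boundary vs. mid-chunk
        old_e = e // chunk_size + p + 1
        new_e = e // chunk_size + p + (e % chunk_size > 0)
        print(e, old_e, new_e)
    # e = 16 -> old 3, new 2: the old bound spilled one chunk past the sequence
    # e = 13 -> old 2, new 2: the shift is still applied for a mid-chunk end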
diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py
new file mode 100644
index 0000000000000..4e2a2c0d7327f
--- /dev/null
+++ b/vllm/model_executor/models/mamba2.py
@@ -0,0 +1,335 @@
+# SPDX-License-Identifier: Apache-2.0
+"""PyTorch MAMBA2 model."""
+from typing import Iterable, List, Optional, Set, Tuple
+
+import torch
+from torch import nn
+from transformers import MambaConfig
+
+from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.config import VllmConfig
+from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import get_pp_group
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.mamba_mixer2 import (
+    MambaMixer2, extra_groups_for_head_shards)
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import (HasInnerState,
+                                                   IsAttentionFree)
+from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
+                                                    MambaCacheParams)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+from vllm.utils import LayerBlockType
+
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class Mamba2DecoderLayer(nn.Module):
+
+    def __init__(self,
+                 config: MambaConfig,
+                 quant_config: Optional[QuantizationConfig] = None) -> None:
+        super().__init__()
+        self.config = config
+        self.mixer = MambaMixer2(hidden_size=config.hidden_size,
+                                 ssm_state_size=config.state_size,
+                                 conv_kernel_size=config.conv_kernel,
+                                 intermediate_size=getattr(
+                                     config, "intermediate_size",
+                                     config.expand * config.hidden_size),
+                                 use_conv_bias=config.use_conv_bias,
+                                 use_bias=config.use_bias,
+                                 n_groups=config.n_groups,
+                                 num_heads=config.num_heads,
+                                 head_dim=config.head_dim,
+                                 rms_norm_eps=config.layer_norm_epsilon,
+                                 activation=config.hidden_act,
+                                 chunk_size=config.chunk_size,
+                                 quant_config=quant_config)
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        residual: Optional[torch.Tensor],
+        mamba_cache_params: MambaCacheParams,
+        sequence_idx: Optional[torch.Tensor],
+        **kwargs,
+    ):
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.norm(hidden_states)
+        else:
+            hidden_states, residual = self.norm(hidden_states, residual)
+
+        hidden_states = self.mixer(hidden_states, attn_metadata,
+                                   mamba_cache_params, sequence_idx)
+        return hidden_states, residual
+
+
+class Mamba2Model(nn.Module):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+        is_lora_enabled = bool(lora_config)
+        assert not is_lora_enabled
+
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        lora_vocab = ((lora_config.lora_extra_vocab_size *
+                       (lora_config.max_loras or 1)) if lora_config else 0)
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
+
+        self.embeddings = VocabParallelEmbedding(
+            self.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+        )
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: Mamba2DecoderLayer(config,
+                                              quant_config=quant_config),
+            prefix=f"{prefix}.layers")
+
+        self.norm_f = RMSNorm(config.hidden_size,
+                              eps=config.layer_norm_epsilon)
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        mamba_cache_params: MambaCacheParams,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+
+        # pass a sequence index tensor, which is required for
+        # proper continuous batching computation including
+        # chunked prefill
+        seq_idx = None
+        if attn_metadata.num_prefills > 0:
+            seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
+            for i, (srt, end) in enumerate(
+                    zip(
+                        attn_metadata.query_start_loc,
+                        attn_metadata.query_start_loc[1:],
+                    )):
+                seq_idx[srt:end] = i
+            seq_idx.unsqueeze_(0)
+
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+
+            hidden_states, residual = layer(
+                positions=positions,
+                hidden_states=hidden_states,
+                attn_metadata=attn_metadata,
+                residual=residual,
+                mamba_cache_params=mamba_cache_params.at_layer_idx(
+                    i - self.start_layer),
+                sequence_idx=seq_idx)
+
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
+
+        hidden_states, _ = self.norm_f(hidden_states, residual)
+
+        return hidden_states
+
+
+class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        lora_config = vllm_config.lora_config
+        scheduler_config = vllm_config.scheduler_config
+        assert not cache_config.enable_prefix_caching, \
+            "Mamba does not support prefix caching"
+
+        super().__init__()
+        self.config = config
+        self.vllm_config = vllm_config
+        self.scheduler_config = scheduler_config
+        self.model_config = vllm_config.model_config
+        self.backbone = Mamba2Model(vllm_config=vllm_config,
+                                    prefix=maybe_prefix(prefix, "backbone"))
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+
+        self.lm_head = ParallelLMHead(
+            self.unpadded_vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE
+            # We need bigger padding if using lora for kernel
+            # compatibility
+            if not lora_config else lora_config.lora_vocab_padding_size,
+        )
+        if config.tie_word_embeddings:
+            self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings)
+
+        # Used to track and store the Mamba cache between steps.
+        self.mamba_cache: Optional[MambaCacheManager] = None
+
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
+        self.sampler = get_sampler()
+
+        self.make_empty_intermediate_tensors = (
+            self.backbone.make_empty_intermediate_tensors)
+        if self.scheduler_config is not None and \
+            not self.model_config.enforce_eager:
+            if self.scheduler_config.max_num_seqs > \
+                vllm_config.compilation_config.max_capture_size:
+                self.max_batch_size = \
+                    vllm_config.compilation_config.max_capture_size
+            else:
+                self.max_batch_size = vllm_config.pad_for_cudagraph(
+                    self.scheduler_config.max_num_seqs)
+        elif self.scheduler_config is not None:
+            # For eager just take the scheduler_config if avail
+            self.max_batch_size = self.scheduler_config.max_num_seqs
+        else:
+            self.max_batch_size = 128 + 2
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.backbone.get_input_embeddings(input_ids)
+
+    def forward(self,
+                input_ids: torch.Tensor,
+                positions: torch.Tensor,
+                kv_caches: List[KVCache],
+                attn_metadata: AttentionMetadata,
+                intermediate_tensors: Optional[IntermediateTensors] = None,
+                inputs_embeds: Optional[torch.Tensor] = None,
+                **kwargs):
+        if self.mamba_cache is None:
+            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
+                self.vllm_config.parallel_config, LayerBlockType.mamba)
+            self.mamba_cache = MambaCacheManager(
+                self.lm_head.weight.dtype, num_mamba_layers,
+                self.max_batch_size, *self._get_mamba_cache_shape())
+
+        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+
+        hidden_states = self.backbone(input_ids, positions, attn_metadata,
+                                      mamba_cache_params,
+                                      intermediate_tensors, inputs_embeds)
+
+        return hidden_states
+
+    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
+        return self.mamba_cache.copy_inputs_before_cuda_graphs(
+            input_buffers, **kwargs)
+
+    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
+        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
+
+    def _get_mamba_cache_shape(
+            self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+        world_size = get_tensor_model_parallel_world_size()
+
+        conv_state_shape, temporal_state_shape = None, None
+
+        intermediate_size = getattr(
+            self.config, "intermediate_size",
+            self.config.expand * self.config.hidden_size)
+
+        # if n_groups is not divisible by world_size, need to extend the
+        # shards to ensure all groups needed by a head are sharded along
+        # with it
+        n_groups = (
+            self.config.n_groups +
+            extra_groups_for_head_shards(self.config.n_groups, world_size))
+
+        # - heads and n_groups are TP-ed
+        conv_dim = (intermediate_size + 2 * n_groups * self.config.state_size)
+        conv_state_shape = (
+            divide(conv_dim, world_size),
+            self.config.conv_kernel - 1,
+        )
+
+        # These are not TP-ed as they depend on A, dt_bias, D
+        # - they are typically small
+        #   e.g., (num_heads, head_dim, state_size) = (128, 64, 128)
+        temporal_state_shape = (
+            divide(self.config.num_heads, world_size),
+            self.config.head_dim,
+            self.config.state_size,
+        )
+        return conv_state_shape, temporal_state_shape
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if "A_log" in name:
+                name = name.replace("A_log", "A")
+
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            if is_pp_missing_parameter(name, self):
+                continue
+
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
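For intuition about what `_get_mamba_cache_shape` returns, here is a back-of-the-envelope sketch with assumed Codestral-like hyperparameters (hidden_size 4096, expand 2, n_groups 8, state_size 128, num_heads 128, head_dim 64, conv_kernel 4; these numbers are illustrative assumptions, not read from this patch) and tensor parallel size 2, where the 8 groups divide evenly so no extra group padding is needed:

    # Assumed Codestral-like values; shapes are per layer and per sequence.
    intermediate_size = 2 * 4096          # expand * hidden_size = 8192
    n_groups, state_size = 8, 128
    num_heads, head_dim = 128, 64
    conv_kernel, world_size = 4, 2

    conv_dim = intermediate_size + 2 * n_groups * state_size       # 10240
    conv_state_shape = (conv_dim // world_size, conv_kernel - 1)   # (5120, 3)
    temporal_state_shape = (num_heads // world_size, head_dim, state_size)  # (64, 64, 128)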
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index c2d0fae7056c7..53ade3aa9bea9 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -71,6 +71,7 @@
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
     "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
     "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
+    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
     "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
     "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
     "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
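With the `Mamba2ForCausalLM` entry registered, a checkpoint whose HF config reports that architecture should be servable through the usual offline API. A minimal usage sketch (the model name comes from the TODO in the test file above; whether that exact repository loads out of the box depends on it shipping a transformers-style Mamba2 config, and a bf16-capable GPU with enough memory is assumed):

    from vllm import LLM, SamplingParams

    llm = LLM(model="mistralai/Mamba-Codestral-7B-v0.1", dtype="bfloat16")
    out = llm.generate(["def quicksort(arr):"],
                       SamplingParams(temperature=0.0, max_tokens=32))
    print(out[0].outputs[0].text)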