From c956a30a27588297d6570124720a66134f72ce58 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 16 Jan 2025 17:49:44 +0000 Subject: [PATCH 01/15] Mamba2 changes from #10909 Signed-off-by: Tyler Michael Smith Co-authored-by: Yu Chin Fabian Lim --- tests/kernels/test_mamba_ssm_ssd.py | 302 +++++++ .../layers/mamba/mamba_mixer2.py | 493 ++++++++++++ .../layers/mamba/ops/mamba_ssm.py | 2 +- .../layers/mamba/ops/ssd_bmm.py | 259 ++++++ .../layers/mamba/ops/ssd_chunk_scan.py | 598 ++++++++++++++ .../layers/mamba/ops/ssd_chunk_state.py | 748 ++++++++++++++++++ .../layers/mamba/ops/ssd_combined.py | 221 ++++++ .../layers/mamba/ops/ssd_state_passing.py | 205 +++++ 8 files changed, 2827 insertions(+), 1 deletion(-) create mode 100644 tests/kernels/test_mamba_ssm_ssd.py create mode 100644 vllm/model_executor/layers/mamba/mamba_mixer2.py create mode 100644 vllm/model_executor/layers/mamba/ops/ssd_bmm.py create mode 100644 vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py create mode 100644 vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py create mode 100644 vllm/model_executor/layers/mamba/ops/ssd_combined.py create mode 100644 vllm/model_executor/layers/mamba/ops/ssd_state_passing.py diff --git a/tests/kernels/test_mamba_ssm_ssd.py b/tests/kernels/test_mamba_ssm_ssd.py new file mode 100644 index 0000000000000..820aeb0e46b60 --- /dev/null +++ b/tests/kernels/test_mamba_ssm_ssd.py @@ -0,0 +1,302 @@ +from typing import Dict, Tuple + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + +from vllm.model_executor.layers.mamba.ops.ssd_combined import ( + mamba_chunk_scan_combined) +from vllm.platforms import current_platform + +# Added by the IBM Team, 2024 + +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py + + +# this is the segsum implementation taken from above +def segsum(x): + """Calculates segment sum.""" + T = x.size(-1) + x = repeat(x, "... d -> ... d e", e=T) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), + diagonal=-1) + x = x.masked_fill(~mask, 0) + x_segsum = torch.cumsum(x, dim=-2) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), + diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): + """ + Arguments: + X: (batch, length, n_heads, d_head) + A: (batch, length, n_heads) + B: (batch, length, n_heads, d_state) + C: (batch, length, n_heads, d_state) + Return: + Y: (batch, length, n_heads, d_head) + """ + assert X.dtype == A.dtype == B.dtype == C.dtype + assert X.shape[1] % block_len == 0 + + # Rearrange into blocks/chunks + X, A, B, C = (rearrange(x, "b (c l) ... -> b c l ...", l=block_len) + for x in (X, A, B, C)) + + A = rearrange(A, "b c l h -> b h c l") + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + L = torch.exp(segsum(A)) + Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum) + states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X) + + # 3. 
Compute the inter-chunk SSM recurrence; produces correct SSM states at + # chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if initial_states is None: + initial_states = torch.zeros_like(states[:, :1]) + states = torch.cat([initial_states, states], dim=1) + decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) + new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) + states, final_state = new_states[:, :-1], new_states[:, -1] + + # 4. Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out) + + # Add output of intra-chunk and inter-chunk terms + # (diagonal and off-diagonal blocks) + Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p") + return Y, final_state + + +def generate_random_inputs(batch_size, + seqlen, + n_heads, + d_head, + itype, + device='cuda'): + + current_platform.seed_everything(0) + A = (-torch.exp(torch.rand(n_heads, dtype=itype, device=device))) + dt = F.softplus( + torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) - + 4) + X = torch.randn((batch_size, seqlen, n_heads, d_head), + dtype=itype, + device=device) + B = torch.randn((batch_size, seqlen, n_heads, d_head), + dtype=itype, + device=device) + C = torch.randn((batch_size, seqlen, n_heads, d_head), + dtype=itype, + device=device) + + return A, dt, X, B, C + + +def generate_continous_batched_examples(example_lens_by_batch, + num_examples, + full_length, + last_taken, + exhausted, + n_heads, + d_head, + itype, + device='cuda'): + + # this function generates a random examples of certain length + # and then cut according to "example_lens_by_batch" and feed + # them in continuous batches to the kernels + + # generate the full-length example + A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads, + d_head, itype) + + Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), + A * dt, + B, + C, + block_len=full_length // 4) + + # internal function that outputs a cont batch of examples + # given a tuple of lengths for each example in the batch + # e.g., example_lens=(8, 4) means take 8 samples from first eg, + # 4 examples from second eg, etc + def get_continuous_batch(example_lens: Tuple[int, ...]): + + indices = [] + for i, x in enumerate(example_lens): + c = last_taken.get(i, 0) + indices.append((c, c + x)) + last_taken[i] = (c + x) % full_length + exhausted[i] = last_taken[i] == 0 + + return (torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices) + ]).unsqueeze(0) for x in (dt, X, B, C)) + + # internal function that maps "n" to the appropriate right boundary + # value when forming continuous batches from examples of length given + # by "full_length". 
+ # - e.g., when n > full_length, returns n % full_length + # when n == full_length, returns full_length + def end_boundary(n: int): + return n - ((n - 1) // full_length) * full_length + + IND_E = None + for spec in example_lens_by_batch: + + # get the (maybe partial) example seen in this cont batch + dt2, X2, B2, C2 = get_continuous_batch(spec) + + # get the metadata + cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0) + sed_idx = torch.zeros(cu_seqlens[-1], + dtype=torch.int32, + device=cu_seqlens.device) + for i, (srt, end) in enumerate(zip( + cu_seqlens, + cu_seqlens[1:], + )): + sed_idx[srt:end] = i + + # for cont batch + if IND_E is None: + IND_S = [0 for _ in range(len(spec))] + else: + IND_S = [x % full_length for x in IND_E] + IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)] + + yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)], + cu_seqlens, sed_idx.unsqueeze(0), (A, dt2, X2, B2, C2)) + + +@pytest.mark.parametrize("itype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32]) +@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128]) +@pytest.mark.parametrize("seq_len_chunk_size", [(119, 17), (128, 32)]) +def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, + itype): + + # this tests the kernels on a single example (no batching) + + # set seed + batch_size = 1 # batch_size + # ssd_minimal_discrete requires chunk_size divide seqlen + # - this is only required for generating the reference seqs, + # it is not an operational limitation. + seqlen, chunk_size = seq_len_chunk_size + + A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads, + d_head, itype) + + Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt, + B, C, chunk_size) + + Y, final_state = mamba_chunk_scan_combined(X, + dt, + A, + B, + C, + chunk_size, + D=None, + return_final_states=True) + + # just test the last in sequence + torch.allclose(Y[:, -1], Y_min[:, -1], atol=1e-3, rtol=1e-3) + + # just test the last head + # NOTE, in the kernel we always cast states to fp32 + torch.allclose(final_state[:, -1], + final_state_min[:, -1].to(torch.float32), + atol=1e-3, + rtol=1e-3) + + +@pytest.mark.parametrize("itype", [torch.float32, torch.float16]) +@pytest.mark.parametrize("n_heads", [4, 8, 13]) +@pytest.mark.parametrize("d_head", [5, 16, 21, 32]) +@pytest.mark.parametrize( + "seq_len_chunk_size_cases", + [ + + # small-ish chunk_size (8) + (64, 8, 2, [(64, 32), (64, 32)]), + (64, 8, 2, [(32, 32), (32, 32), (32, 32)]), + (64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary + (64, 8, 2, [(4, 4), (4, 4), (4, 4), + (4, 4)]), # chunk_size larger than cont batches + (64, 8, 5, [ + (64, 32, 16, 8, 8), + (8, 16, 32, 16, 8), + (8, 8, 16, 32, 16), + ]), # mode examples with varied lengths + + # odd chunk_size + (64, 29, 2, [(11, 4), (13, 23), (19, 22), + (21, 15)]), # irregular sizes + + # large-ish chunk_size (256) + (64, 256, 1, [(5, ), (1, ), (1, ), + (1, )]), # irregular sizes with small sequences + (64, 256, 2, [(5, 30), (1, 2), (1, 2), + (1, 2)]), # irregular sizes with small sequences + ]) +def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, + itype): + + # this test with multiple examples in a continuous batch + # (i.e. 
chunked prefill) + + seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases + + # hold state during the cutting process so we know if an + # example has been exhausted and needs to cycle + last_taken: Dict = {} # map: eg -> pointer to last taken sample + exhausted: Dict = {} # map: eg -> boolean indicating example is exhausted + + states = None + for Y_min, cu_seqlens, sed_idx, (A, dt, X, B, + C) in generate_continous_batched_examples( + cases, num_examples, seqlen, + last_taken, exhausted, n_heads, + d_head, itype): + + Y, new_states = mamba_chunk_scan_combined( + X, + dt, + A, + B, + C, + chunk_size, + D=None, + cu_seqlens=cu_seqlens, + seq_idx=sed_idx, + return_varlen_states=True, + initial_states=states, + ) + + # just test the last in sequence + for i in range(num_examples): + + # just test one dim and dstate + Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0] + Y_min_eg = Y_min[i][:, 0, 0] + torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3) + + # update states + states = new_states + for i, clear in exhausted.items(): + if clear: + states[i].fill_(0.) + exhausted[i] = False diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py new file mode 100644 index 0000000000000..c1482a17b0328 --- /dev/null +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -0,0 +1,493 @@ +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.attention.backends.xformers import XFormersMetadata +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_state_update) +from vllm.model_executor.layers.mamba.ops.ssd_combined import ( + mamba_chunk_scan_combined) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import ( + LoaderFunction, composed_weight_loader, sharded_weight_loader) +from vllm.model_executor.models.mamba_cache import MambaCacheParams +from vllm.model_executor.utils import set_weight_attrs + +# Added by the IBM Team, 2024 + + +# Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated +@CustomOp.register("mixer2_gated_rms_norm") +class Mixer2RMSNormGated(CustomOp): + + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.hidden_size = hidden_size + self.variance_epsilon = eps + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.tp_size = get_tensor_model_parallel_world_size() + set_weight_attrs(self.weight, + {"weight_loader": sharded_weight_loader(0)}) + assert self.hidden_size % self.tp_size== 0,\ + "Tensor parallel world size must divide hidden size." 
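+
+    # When the hidden dimension is sharded across tensor-parallel ranks, the
+    # RMS variance must be computed over the full (global) hidden size, so
+    # forward_native all-reduces the per-rank sums of squares before
+    # normalizing; forward_cuda falls back to this path whenever tp_size > 1.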
+ + def forward_native( + self, + x: torch.Tensor, + gate: torch.Tensor, + ): + input_dtype = x.dtype + x = x * nn.functional.silu(gate.to(torch.float32)) + + if self.tp_size > 1: + # Compute local sum and then reduce to obtain global sum + local_sums = x.pow(2).sum(dim=-1, keepdim=True) + global_sums = tensor_model_parallel_all_reduce(local_sums) + # Calculate the variance + count = self.tp_size * x.shape[-1] + variance = (global_sums / count) + + else: + variance = x.pow(2).mean(-1, keepdim=True) + + x = x * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * x.to(input_dtype) + + def forward_cuda( + self, + x: torch.Tensor, + gate: torch.Tensor, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + + if self.tp_size > 1: + return self.forward_native(x, gate) + + from vllm import _custom_ops as ops + + # cast x and gate to float32 before silu + out = torch.empty_like(x) + y = x * nn.functional.silu(gate.to(torch.float32)) + ops.rms_norm( + out, + y.to(x.dtype), + self.weight.data, + self.variance_epsilon, + ) + return out + + +def extra_groups_for_head_shards(ngroups: int, tp_size: int): + """Compute the increase in group numbers to account for + replication in order to accompany the head shards.""" + + # in the case ngoups % tp_size == 0, this will be zero + if ngroups % tp_size == 0: + return 0 + + return tp_size - ngroups % tp_size + + +def mamba_v2_sharded_weight_loader( + shard_spec: List[Tuple[int, int, float]], + tp_size: int, + tp_rank: int, +) -> LoaderFunction: + """Create a weight loader for mamba v2. This ensures that the projections + are correctly sharded so that they can be split into x, B, C. It also + ensures the the all the groups corresponding to a head shard is placed + together with it. + """ + + def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + + # - track boundary of (sharded) param, and loaded_weight, respectively + boundary, loaded_boundary = 0, 0 + + # - iterate over the shard specs + for full_dim, extra, ratio in shard_spec: + # - full dim is the model dim (before TP). + # - extra > 0, means there is expected overall increase + # of dimensions. This is so because of replication. + # - ratio is used map the tp_rank to the actual shard + # rank. This is useful when there is replication of + # groups to accompany head shards. + + # - size of the loaded shard + shard_size = full_dim // tp_size + + # - compute the rank into the loaded shard. + # - if there is replication, different TP shards will + # take from the same rank. + rank = tp_rank // ratio + + # - leftmost boundary index into loaded weight. + loaded_skip = rank * shard_size + loaded_start_idx = loaded_boundary + loaded_skip + + # - take these many dims from the loaded weight. + take = min(shard_size, full_dim - extra - loaded_skip) + + # - always shard on dim 0 + # - the ignore is for a mundane mypy error as it does not + # seem to handle slices well. + # https://github.com/python/mypy/issues/2410 + param.data[boundary:(boundary + take), # type: ignore[misc] + ...] = loaded_weight[ + loaded_start_idx:( # type: ignore[misc] + loaded_start_idx + take)] # type: ignore[misc] + + # move indexing boundaries + boundary += shard_size + loaded_boundary += (full_dim - extra) + + return loader + + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +@CustomOp.register("mamba_mixer2") +class MambaMixer2(CustomOp): + """ + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. 
A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) + """ + + def __init__(self, + hidden_size: int, + ssm_state_size: int, + conv_kernel_size: int, + intermediate_size: int, + use_conv_bias: bool, + use_bias: bool, + use_rms_norm: bool, + n_groups: int = 1, + num_heads: int = 128, + head_dim: int = 64, + rms_norm_eps: float = 1e-5, + activation="silu", + chunk_size: int = 256, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + + # For TP, the sharding plan is as follows: + # - for the conv modules, since + # conv_dim = intermediate_size * 2 * n_groups * ssm_state_size, + # we shard intermediate_size and n_groups + # - since intermediate_size = n_heads * head_dim, sharding on + # intermediate_size is achieved by sharding on n_heads. + # - IF, world_size divides groups, then sharding + # (n_groups / world_size, n_heads / world_size) + # also maintains the invariant n_heads % n_groups == 0 + # - HOWEVER IF, world_size DOES NOT divide groups, then we need + # to allocate extra space in the shard, such that groups + # may be replicated to follow the head shard. + self.tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + assert num_heads % self.tp_size == 0, \ + "Tensor parallel world size must divide num heads." + + self.ssm_state_size = ssm_state_size + self.use_rms_norm = use_rms_norm + self.activation = activation + + self.chunk_size = chunk_size + self.intermediate_size = intermediate_size + self.head_dim = head_dim + self.num_heads = num_heads + + self.n_groups = n_groups + if n_groups % self.tp_size != 0: + # - for TP we shard conv_dim by sharding on n_groups, + # - but if n_groups cannot divide tp_size, we need to + # extend some extra groups + self.n_groups = n_groups + extra_groups_for_head_shards( + n_groups, self.tp_size) + + self.conv_dim = (intermediate_size + + 2 * self.n_groups * ssm_state_size) + self.conv1d = ColumnParallelLinear( + input_size=conv_kernel_size, + output_size=self.conv_dim, + bias=use_conv_bias, + quant_config=None, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. 
+ # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = ColumnParallelLinear(input_size=hidden_size, + output_size=intermediate_size + + self.conv_dim + self.num_heads, + bias=use_bias, + quant_config=quant_config) + + # - because in_proj is a concatenation of 3 weights, we + # need to interleave them before sharding + # - use the custom weight loader mamba_v2_sharded_weight_loader + # for conv1d.bias, covn1d.weight and in_proj.weight + # - need to set these settings, to assign the groups to the head shards + group_shard_settings = ( + self.n_groups * self.ssm_state_size, # expected model size + (self.n_groups - n_groups) * + self.ssm_state_size, # extra dims assigned + self.num_heads // + n_groups, # ratio for mapping back to original group + ) + intermediate_settings = (intermediate_size, 0, 1) + head_setings = (self.num_heads, 0, 1) + + # - the weight already has a "weight_loader" attribute + # which set_weight_attrs will raise if we do not + # delete before trying to override it + # - ditto for the otther two weights below + delattr(self.conv1d.bias, "weight_loader") + set_weight_attrs( + self.conv1d.bias, { + "weight_loader": + mamba_v2_sharded_weight_loader( + [ + intermediate_settings, + group_shard_settings, + group_shard_settings, + ], + self.tp_size, + tp_rank, + ) + }) + + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, { + "weight_loader": + mamba_v2_sharded_weight_loader([ + intermediate_settings, + group_shard_settings, + group_shard_settings, + ], self.tp_size, tp_rank) + }) + + delattr(self.in_proj.weight, "weight_loader") + set_weight_attrs( + self.in_proj.weight, + { + "weight_loader": + mamba_v2_sharded_weight_loader( + [ + intermediate_settings, # for gate + intermediate_settings, + group_shard_settings, + group_shard_settings, + head_setings, # for dt + ], + self.tp_size, + tp_rank) + }) + + # - these are TPed by heads to reduce the size of the + # temporal shape + self.A = nn.Parameter( + torch.empty( + divide(num_heads, self.tp_size), + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(num_heads // self.tp_size)) + self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size)) + + set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) + a_weight_loader = composed_weight_loader( + sharded_weight_loader(0), lambda x: -torch.exp(x.float())) + set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) + set_weight_attrs(self.dt_bias, + {"weight_loader": sharded_weight_loader(0)}) + + self.out_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=use_bias, + input_is_parallel=True, + quant_config=quant_config) + + self.norm = Mixer2RMSNormGated(intermediate_size // self.tp_size, + eps=rms_norm_eps) + + def forward_native(self, hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, ssm_state: torch.Tensor): + pass + + def forward_cuda( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams, + sequence_idx: Optional[torch.Tensor] = None, + ): + + seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + + # detect if there are prefills + has_prefill = attn_metadata.num_prefills > 0 + + # - also need flags to indicate if there are initial states + # - currently we really only support the 
FlashAttention backend + has_initial_states = None + if (isinstance(attn_metadata, + (FlashAttentionMetadata, XFormersMetadata)) + and attn_metadata.context_lens_tensor is not None): + has_initial_states = attn_metadata.context_lens_tensor > 0 + + # 1. Gated MLP's linear projection + projected_states, _ = self.in_proj(hidden_states) + gate, hidden_states_B_C, dt = torch.split( + projected_states, + [ + self.intermediate_size // self.tp_size, + self.conv_dim // self.tp_size, + self.num_heads // self.tp_size, + ], + dim=-1, + ) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if has_prefill: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # - "cache_indices" updates the conv_state cache in positions + # pointed to by "mamba_cache_params.state_indices_tensor" + hidden_states_B_C = causal_conv1d_fn( + hidden_states_B_C.transpose(0, 1), + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=mamba_cache_params.conv_state, + has_initial_state=has_initial_states, + cache_indices=mamba_cache_params.state_indices_tensor, + query_start_loc=attn_metadata.query_start_loc).transpose( + 0, 1)[:seq_len] + else: + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + mamba_cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=mamba_cache_params.state_indices_tensor) + + # - get hidden_states, B and C after depthwise convolution. + hidden_states, B, C = torch.split( + hidden_states_B_C, + [ + self.intermediate_size // self.tp_size, + groups_time_state_size // self.tp_size, + groups_time_state_size // self.tp_size, + ], + dim=-1, + ) + + # 3. 
State Space Model sequence transformation + if has_prefill: + + initial_states = None + if has_initial_states is not None and any(has_initial_states): + for idx in mamba_cache_params.state_indices_tensor[ + ~has_initial_states]: + mamba_cache_params.ssm_state[idx].zero_() + initial_states = mamba_cache_params.ssm_state[ + mamba_cache_params.state_indices_tensor] + + scan_output, varlen_state = mamba_chunk_scan_combined( + hidden_states.view(1, seq_len, self.num_heads // self.tp_size, + self.head_dim), + dt.unsqueeze(0), + self.A, + B.view(1, seq_len, self.n_groups // self.tp_size, -1), + C.view(1, seq_len, self.n_groups // self.tp_size, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + dt_bias=self.dt_bias, + seq_idx=sequence_idx, + cu_seqlens=attn_metadata.query_start_loc, + initial_states=initial_states, + return_varlen_states=True, + return_final_states=False, + dt_softplus=True, + dt_limit=(0.0, float("inf")), + ) + + # update ssm states + # - varlen state is a (batch, nheads, headdim, dstate) tensor + for i, idx in enumerate(mamba_cache_params.state_indices_tensor): + mamba_cache_params.ssm_state[idx].copy_(varlen_state[i]) + + # - reshape + hidden_states = scan_output.view(seq_len, -1) + else: + + n_groups = self.n_groups // self.tp_size + A = self.A[:, None, ...][:, :, None].expand( + -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(-1, n_groups, B.shape[1] // n_groups) + C = C.view(-1, n_groups, C.shape[1] // n_groups) + hidden_states_reshaped = hidden_states.view( + -1, self.num_heads // self.tp_size, self.head_dim) + + # - the hidden is reshaped into number of current batches + # - in this case there is no more prefil, so the batches gen + # 1 token at a time + # - thus hidden will be (bs, num_heads, head_dim) + # - mamba_cache_params.ssm_state's slots will be selected + # using "mamba_cache_params.state_indices_tensor", just as + # above in the prefill case + + hidden_states = selective_state_update( + mamba_cache_params.ssm_state, + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=mamba_cache_params.state_indices_tensor, + ) + hidden_states = hidden_states.view( + -1, (self.num_heads // self.tp_size) * self.head_dim) + + # # 4. gated MLP + hidden_states = self.norm(hidden_states, gate) + + # # 5. Final linear projection + out, _ = self.out_proj(hidden_states) + return out diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 1484b79815ab9..898c2a90553d0 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -1,5 +1,5 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. -# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py import torch import triton diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py new file mode 100644 index 0000000000000..20a4e3e6177ef --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -0,0 +1,259 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. 
+# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py + +# ruff: noqa: E501,SIM102 + +import math + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['chunk_size', 'K', 'IS_CAUSAL'], +) +@triton.jit +def _bmm_chunk_fwd_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + out_ptr, + seq_idx_ptr, + # Matrix dimensions + seqlen, + chunk_size, + K, + ngroups, + stride_a_batch, + stride_a_seqlen, + stride_a_head, + stride_ak, + stride_b_batch, + stride_b_seqlen, + stride_b_head, + stride_bk, + stride_out_batch, + stride_out_chunk, + stride_out_head, + stride_outm, + stride_outn, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + dot_dtype: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_ch = tl.program_id(axis=2).to(tl.int64) + pid_c = pid_ch // ngroups + pid_h = pid_ch - pid_c * ngroups + num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + if IS_CAUSAL: + if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: + return + a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + + offs_n[None, :] * stride_b_seqlen) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0).to(dot_dtype) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & + (offs_n[None, :] < chunk_size_limit), + other=0.0).to(dot_dtype) + acc += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * 
stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if HAS_SEQ_IDX: + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + mask=offs_m < chunk_size_limit, + other=-1) + seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, + mask=offs_n < chunk_size_limit, + other=-2) + acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) + out = acc.to(out_ptr.dtype.element_ty) + + out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head + out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + + offs_n[None, :] * stride_outn) + tl.store(out_ptrs, + out, + mask=(offs_m[:, None] < chunk_size) & + (offs_n[None, :] < chunk_size)) + + +def _bmm_chunk_fwd(a, + b, + chunk_size, + seq_idx=None, + causal=False, + output_dtype=None): + """ + Argument: + a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + b: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out. + causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are + guaranteed to be correct. + Return: + out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size) + """ + # Check constraints. + has_groups = a.dim() == 4 + if not has_groups: + batch, seqlen, k = a.shape + else: + batch, seqlen, ngroups, k = a.shape + assert b.shape == a.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if a.stride(-1) != 1 and a.stride(1) != 1: + a = a.contiguous() + if b.stride(-1) != 1 and b.stride(1) != 1: + b = b.contiguous() + nchunks = math.ceil(seqlen / chunk_size) + # Allocates output. + out_dtype = a.dtype if output_dtype is None else output_dtype + out = torch.empty( + (batch, nchunks, chunk_size, chunk_size) if not has_groups else + (batch, nchunks, ngroups, chunk_size, chunk_size), + device=a.device, + dtype=out_dtype) + dot_dtype = (tl.bfloat16 + if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else + (tl.float16 if a.dtype == torch.float16 + or b.dtype == torch.float16 else tl.float32)) + grid = lambda META: (triton.cdiv( + chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( + chunk_size, META['BLOCK_SIZE_N']), batch, nchunks + if not has_groups else nchunks * ngroups) + with torch.cuda.device(a.device.index): + _bmm_chunk_fwd_kernel[grid]( + a, + b, + out, + seq_idx, + seqlen, + chunk_size, + k, + ngroups if has_groups else 1, + a.stride(0), + a.stride(1), + 0 if not has_groups else a.stride(2), + a.stride(-1), + b.stride(0), + b.stride(1), + 0 if not has_groups else b.stride(2), + b.stride(-1), + out.stride(0), + out.stride(1), + 0 if not has_groups else out.stride(2), + out.stride(-2), + out.stride(-1), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + causal, + dot_dtype, + HAS_SEQ_IDX=seq_idx is not None, + ) + return out diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py new file mode 100644 index 0000000000000..a4e0b1fc24906 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -0,0 +1,598 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. 
+# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py + +# ruff: noqa: E501,SIM102 + +import torch +import triton +import triton.language as tl +from packaging import version + +TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 64 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'], +) +@triton.jit +def _chunk_scan_fwd_kernel( + # Pointers to matrices + cb_ptr, + x_ptr, + z_ptr, + out_ptr, + out_x_ptr, + dt_ptr, + dA_cumsum_ptr, + seq_idx_ptr, + C_ptr, + states_ptr, + D_ptr, + initstates_ptr, + chunk_indices_ptr, + chunk_offsets_ptr, + chunk_meta_num, + # Matrix dimensions + chunk_size, + hdim, + dstate, + batch, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_cb_batch, + stride_cb_chunk, + stride_cb_head, + stride_cb_csize_m, + stride_cb_csize_k, + stride_x_batch, + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_z_batch, + stride_z_seqlen, + stride_z_head, + stride_z_hdim, + stride_out_batch, + stride_out_seqlen, + stride_out_head, + stride_out_hdim, + stride_dt_batch, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + stride_C_batch, + stride_C_seqlen, + stride_C_head, + stride_C_dstate, + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_init_states_batch, + stride_init_states_head, + stride_init_states_hdim, + stride_init_states_dstate, + stride_D_head, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, + IS_TRITON_22: tl.constexpr, + HAS_INITSTATES: tl.constexpr, +): + pid_bc = tl.program_id(axis=1).to(tl.int64) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + if not HAS_INITSTATES: + c_idx = pid_c + c_off = 0 + else: + c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, 
other=0) + c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0) + + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + cb_ptr += pid_b * stride_cb_batch + c_idx * stride_cb_chunk + ( + pid_h // nheads_ngroups_ratio) * stride_cb_head + x_ptr += pid_b * stride_x_batch + c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + C_ptr += pid_b * stride_C_batch + c_idx * chunk_size * stride_C_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_C_head + + # M-block offsets and prev states + # - logic in next block may override these if there is an active offset + offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) + prev_states_ptr = states_ptr + pid_b * stride_states_batch + c_idx * stride_states_chunk + pid_h * stride_states_head + prev_states_hdim = stride_states_hdim + prev_states_dstate = stride_states_dstate + + chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size) + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen + + # - seq_idx_prev points to be previous (possibly logical) chunk. + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, + mask=pid_c >= 1, + other=0) + + if HAS_INITSTATES: + # if there are init states, we only need seq_idx_m to point + # what is the current seq_idx + + # get current seq idx + if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit: + seq_idx_m = tl.load( + seq_idx_ptr + + (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, ) + + # - recall that in ssd_state_passing, for the case c_off == 0 + # i.e., the very first sequence, we made states_ptr hold its initial state + # so this edge case is taken care of + if ((c_off == 0) and + (seq_idx_prev != seq_idx_m + ) # if a seq is changed exactly on boundary + or (c_off > 0) # implies a new example (pseudo chunk) + ): + + # - replace prev_states_ptr with init_states + prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head + prev_states_hdim = stride_init_states_hdim # override strides + prev_states_dstate = stride_init_states_dstate + + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, + mask=offs_m < chunk_size, + other=0.0).to(tl.float32) + + # - handle chunk state limit + if HAS_INITSTATES: + + # have to split this if otherwise compilation will have problems + dA_cs_m_boundary = 0.0 + + # get the c_idx for the next (logica) chunk + c_idx_n = tl.load( + chunk_indices_ptr + (pid_c + 1), + mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, + other=-1 # to trigger different chunk + ) + + # - there are things to consider + # A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct + # contribution of past states + # B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to + # encroach into the next sequence, where c_off_n is the offset of the next + # (logical) chunk. + # An equivalent check for B is c_idx == c_idx_n, where there is repetition in + # (logical) chunk indices. 
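+        # In either case the offset of the next (logical) chunk, c_off_n, is
+        # loaded below: it clamps chunk_size_limit for case B, and dA_cumsum
+        # is re-read just before c_off to rebase the decay factor for case A.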
+ + if (c_idx == c_idx_n) or c_off > 0: + + # get the next offset + c_off_n = tl.load(chunk_offsets_ptr + (pid_c + 1), + mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, + other=chunk_size) + + # in this case, adjust down the chunk_size_limit + if c_idx == c_idx_n: + chunk_size_limit = min(c_off_n, chunk_size_limit) + + # get the cs at the offset boundary + # - c_off == 0 is a passthrough + dA_cs_m_boundary = tl.load( + dA_cumsum_ptr + + (pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize, + mask=(pid_m * BLOCK_SIZE_M + c_off - 1) > -1, + other=0.0).to(tl.float32) + + if HAS_SEQ_IDX: + # - handle seq idx when HAS_INITSTATES==False + if not HAS_INITSTATES: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + mask=offs_m < chunk_size_limit, + other=-1) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Without the if (pid_c > -1), with Triton 2.1.0, I get + # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed. + # With Triton 2.2.0, this works + if IS_TRITON_22 or c_idx > -1: + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k_dstate = tl.arange( + 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + + offs_k_dstate[None, :] * stride_C_dstate) + + prev_states_ptrs = prev_states_ptr + ( + offs_n[None, :] * prev_states_hdim + + offs_k_dstate[:, None] * prev_states_dstate) + if HAS_SEQ_IDX: + + if not HAS_INITSTATES: + # - this is for continuous batching where there is no init states + scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), + 0.0) + else: + # - if there is initstates, we will rely on prev_states, no zeroing + # required. + scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary) + else: + scale_m = tl.exp(dA_cs_m) + if BLOCK_SIZE_DSTATE <= 128: + C = tl.load(C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k_dstate[None, :] < dstate), + other=0.0) + + prev_states = tl.load(prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc = tl.dot(C, prev_states) * scale_m[:, None] + else: + for k in range(0, dstate, BLOCK_SIZE_K): + C = tl.load(C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k_dstate[None, :] < dstate - k), + other=0.0) + # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) + prev_states = tl.load( + prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate - k) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc += tl.dot(C, prev_states) + C_ptrs += BLOCK_SIZE_K + prev_states_ptrs += BLOCK_SIZE_K + acc *= scale_m[:, None] + + offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + + offs_k[None, :] * stride_cb_csize_k) + x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + + offs_n[None, :] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + K_MAX = chunk_size_limit if not IS_CAUSAL else min( + (pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + for k in range(0, K_MAX, BLOCK_SIZE_K): + cb = tl.load(cb_ptrs, + mask=(offs_m[:, None] < chunk_size) & + (offs_k[None, :] < chunk_size - k), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size - k, + other=0.0).to(tl.float32) + # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != 
seq_idx[j]. + # So we don't need masking wrt seq_idx here. + cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :]) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, + other=0.0).to(tl.float32) + cb *= dt_k + if IS_CAUSAL: + mask = offs_m[:, None] >= k + offs_k[None, :] + cb = tl.where(mask, cb, 0.0) + cb = cb.to(x_ptr.dtype.element_ty) + x = tl.load(x_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < hdim), + other=0.0) + acc += tl.dot(cb, x) + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + if HAS_D: + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, + mask=offs_n < hdim, + other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + + offs_n[None, :] * stride_x_hdim), + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_n[None, :] < hdim), + other=0.0).to(tl.float32) + acc += x_residual * D + + if HAS_Z: + out_x_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + + offs_out_n[None, :]) + tl.store(out_x_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim)) + + z_ptr += pid_b * stride_z_batch + c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head + z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + + stride_z_hdim * offs_out_n[None, :]) + z = tl.load(z_ptrs, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim), + other=0.0).to(tl.float32) + acc *= z * tl.sigmoid(z) + + out_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + + offs_out_n[None, :] * stride_out_hdim) + tl.store(out_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim)) + + +def _chunk_scan_fwd( + cb, + x, + dt, + dA_cumsum, + C, + states, + D=None, + z=None, + seq_idx=None, + initial_states=None, +): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads, ) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + + chunk_indices, chunk_offsets = None, None + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + + if initial_states is not None: + # with initial states, we need to take care of how + # seq_idx crosses the boundaries + assert batch == 1, "chunk scan only supports initial states with batch 1" + assert initial_states.shape == (seq_idx[0].max() + 1, nheads, + headdim, dstate) + + if initial_states.shape[0] == 1: + # no in this case no point to use initial states + initial_states = None + else: + p = 0 + chunk_indices, chunk_offsets = [], [] + for i, idx in 
enumerate(seq_idx[0]): + o = i % chunk_size + c = idx > p + if o == 0 or c: + # this means we have a change in sequence + # - that does not accur on the chunk boundary + chunk_indices.append(i // chunk_size) + chunk_offsets.append(o) + + if c: + p = idx # new sequence + + chunk_indices = torch.tensor(chunk_indices, + dtype=torch.int, + device=seq_idx.device) + chunk_offsets = torch.tensor(chunk_offsets, + dtype=torch.int, + device=seq_idx.device) + + # Allocates output. + out = torch.empty(batch, + seqlen, + nheads, + headdim, + device=x.device, + dtype=x.dtype) + if z is not None: + out_x = torch.empty(batch, + seqlen, + nheads, + headdim, + device=x.device, + dtype=x.dtype) + assert out_x.stride() == out.stride() + else: + out_x = None + + grid = lambda META: ( + triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( + headdim, META['BLOCK_SIZE_N']), batch * nchunks + if chunk_offsets is None else len(chunk_offsets), nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2), + z.stride(3)) if z is not None else (0, 0, 0, 0)) + _chunk_scan_fwd_kernel[grid]( + cb, + x, + z, + out, + out_x, + dt, + dA_cumsum, + seq_idx, + C, + states, + D, + initial_states, + chunk_indices, + chunk_offsets, + len(chunk_indices) if chunk_indices is not None else 0, + chunk_size, + headdim, + dstate, + batch, + seqlen, + nheads // ngroups, + cb.stride(0), + cb.stride(1), + cb.stride(2), + cb.stride(3), + cb.stride(4), + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + z_strides[0], + z_strides[1], + z_strides[2], + z_strides[3], + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + dt.stride(0), + dt.stride(2), + dt.stride(1), + dt.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else + (0, 0)), + C.stride(0), + C.stride(1), + C.stride(2), + C.stride(3), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + states.stride(4), + *((initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3)) if initial_states is not None else + (0, 0, 0, 0)), + D.stride(0) if D is not None else 0, + True, + D is not None, + D.dim() == 2 if D is not None else True, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + HAS_Z=z is not None, + HAS_SEQ_IDX=seq_idx is not None, + IS_TRITON_22=TRITON_22, + HAS_INITSTATES=initial_states is not None, + ) + return out, out_x diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py new file mode 100644 index 0000000000000..fa65e0d84c64f --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -0,0 +1,748 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. 
+# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_state.py + +# ruff: noqa: E501 + +import math + +import torch +import triton +import triton.language as tl + +from .mamba_ssm import softplus + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_H': 1}), + triton.Config({'BLOCK_SIZE_H': 2}), + triton.Config({'BLOCK_SIZE_H': 4}), + triton.Config({'BLOCK_SIZE_H': 8}), + triton.Config({'BLOCK_SIZE_H': 16}), + triton.Config({'BLOCK_SIZE_H': 32}), + triton.Config({'BLOCK_SIZE_H': 64}), + ], + key=['chunk_size', 'nheads'], +) +@triton.jit +def _chunk_cumsum_fwd_kernel( + # Pointers to matrices + dt_ptr, + A_ptr, + dt_bias_ptr, + dt_out_ptr, + dA_cumsum_ptr, + # Matrix dimension + batch, + seqlen, + nheads, + chunk_size, + dt_min, + dt_max, + # Strides + stride_dt_batch, + stride_dt_seqlen, + stride_dt_head, + stride_A_head, + stride_dt_bias_head, + stride_dt_out_batch, + stride_dt_out_chunk, + stride_dt_out_head, + stride_dt_out_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_CHUNK: tl.constexpr, +): + pid_b = tl.program_id(axis=0) + + # if dt is long, may cause problems, so use 64 bit + # https://github.com/triton-lang/triton/issues/1058 + pid_c = tl.program_id(axis=1).to(tl.int64) + pid_h = tl.program_id(axis=2) + dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen + dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + + offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) + dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + + offs_c[None, :] * stride_dt_seqlen) + A_ptrs = A_ptr + offs_h * stride_A_head + dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + + offs_c[None, :] * stride_dt_out_csize) + dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + + offs_c[None, :] * stride_dA_cs_csize) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + dt = tl.load(dt_ptrs, + mask=(offs_h[:, None] < nheads) & + (offs_c[None, :] < chunk_size_limit), + other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, + mask=offs_h < nheads, + other=0.0).to(tl.float32) + dt += dt_bias[:, None] + if DT_SOFTPLUS: + dt = tl.where(dt <= 20.0, softplus(dt), dt) + # As of Triton 2.2.0, tl.clamp is not available yet + # dt = tl.clamp(dt, dt_min, dt_max) + dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) + dt = tl.where( + (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, + 0.0) + tl.store(dt_out_ptrs, + dt, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) + dA = dt * A[:, None] + dA_cs = tl.cumsum(dA, axis=1) + tl.store(dA_cs_ptrs, + dA_cs, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 
'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['hdim', 'dstate', 'chunk_size'], +) +@triton.jit +def _chunk_state_fwd_kernel( + # Pointers to matrices + x_ptr, + b_ptr, + states_ptr, + dt_ptr, + dA_cumsum_ptr, + seq_idx_ptr, + # Matrix dimensions + hdim, + dstate, + chunk_size, + batch, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_x_batch, + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_b_batch, + stride_b_seqlen, + stride_b_head, + stride_b_dstate, + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_dt_batch, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1).to(tl.int64) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_b_head + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + + offs_k[None, :] * stride_x_seqlen) + b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + + offs_k[:, None] * stride_b_seqlen) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + if HAS_SEQ_IDX: + seq_idx_last = tl.load(seq_idx_ptr + + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load(x_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_k[None, :] < chunk_size_limit - k), + other=0.0) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < 
dstate), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_k = tl.load(seq_idx_ptrs, + mask=offs_k < chunk_size_limit - k, + other=-1) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k + else: + scale = tl.where(seq_idx_k == seq_idx_last, + tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0) + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen + states = acc.to(states_ptr.dtype.element_ty) + + states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + + offs_n[None, :] * stride_states_dstate) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['hdim', 'dstate', 'chunk_size'], +) +@triton.jit +def _chunk_state_varlen_kernel( + # Pointers to matrices + x_ptr, + b_ptr, + dt_ptr, + dA_cumsum_ptr, + chunk_states_ptr, + cu_seqlens_ptr, + states_ptr, + initstates_ptr, + # Matrix dimensions + hdim, + dstate, + chunk_size, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_b_seqlen, + stride_b_head, + stride_b_dstate, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_chunk_states_chunk, + stride_chunk_states_head, + stride_chunk_states_hdim, + stride_chunk_states_dstate, + stride_states_batch, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_init_states_batch, + stride_init_states_head, + stride_init_states_hdim, + stride_init_states_dstate, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + HAS_INITSTATES: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = 
tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + end_idx = tl.load(cu_seqlens_ptr + pid_b + 1) + pid_c = (end_idx - 1) // chunk_size + b_ptr += pid_c * chunk_size * stride_b_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_b_head + x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head + + if HAS_INITSTATES: + # if there are init states provided, we differentiate between states (which + # are boundary conditions at a chunk boundary) and initstates (which are boundary + # conditions when a new example in a cont batch starts) + initstates_ptr += pid_h * stride_init_states_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + + offs_k[None, :] * stride_x_seqlen) + b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + + offs_k[:, None] * stride_b_seqlen) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * + stride_dA_cs_csize).to(tl.float32) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + + chunk_size_limit = end_idx - pid_c * chunk_size + start_idx = tl.load(cu_seqlens_ptr + pid_b) + start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load(x_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_k[None, :] < chunk_size_limit - k) & + (offs_k[None, :] >= start_idx_cur - k), + other=0.0) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < dstate) & + (offs_k[:, None] >= start_idx_cur - k), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + scale = tl.where( + (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), + tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0) + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk + # If HAS_INITSTATES==True need to consider two possiblties + # - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs + # - if state_idx >= pid * chunk_size, then we need to insert initstates + if ((start_idx < pid_c * chunk_size) # first chunk + or (HAS_INITSTATES)): + + dA_cs_boundary = 0.0 # default + + if not HAS_INITSTATES: + past_states_ptrs = chunk_states_ptr + ( + offs_m[:, None] * stride_chunk_states_hdim + + offs_n[None, :] * stride_chunk_states_dstate) + else: + + # - this seems repetitve, buts its to help the compiler + if start_idx < pid_c * chunk_size: + past_states_ptrs = chunk_states_ptr + ( + offs_m[:, None] * stride_chunk_states_hdim + + offs_n[None, :] * stride_chunk_states_dstate) + 
else: + past_states_ptrs = initstates_ptr + ( + pid_b * stride_init_states_batch + + offs_m[:, None] * stride_init_states_hdim + + offs_n[None, :] * stride_init_states_dstate) + + # need to adjust the boundary + if start_idx > pid_c * chunk_size: + dA_cs_boundary = tl.load(dA_cumsum_ptr + + (start_idx - pid_c * chunk_size - + 1) * stride_dA_cs_csize).to( + tl.float32) + + past_states = tl.load(past_states_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_n[None, :] < dstate), + other=0.0).to(tl.float32) + + scale = tl.exp(dA_cs_last - dA_cs_boundary) + acc += past_states * scale + + states = acc.to(states_ptr.dtype.element_ty) + + states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + + offs_n[None, :] * stride_states_dstate) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +def _chunk_cumsum_fwd(dt, + A, + chunk_size, + dt_bias=None, + dt_softplus=False, + dt_limit=(0.0, float("inf"))): + batch, seqlen, nheads = dt.shape + assert A.shape == (nheads, ) + if dt_bias is not None: + assert dt_bias.shape == (nheads, ) + nchunks = math.ceil(seqlen / chunk_size) + dt_out = torch.empty(batch, + nheads, + nchunks, + chunk_size, + device=dt.device, + dtype=torch.float32) + dA_cumsum = torch.empty(batch, + nheads, + nchunks, + chunk_size, + device=dt.device, + dtype=torch.float32) + grid_chunk_cs = lambda META: (batch, nchunks, + triton.cdiv(nheads, META['BLOCK_SIZE_H'])) + with torch.cuda.device(dt.device.index): + _chunk_cumsum_fwd_kernel[grid_chunk_cs]( + dt, + A, + dt_bias, + dt_out, + dA_cumsum, + batch, + seqlen, + nheads, + chunk_size, + dt_limit[0], + dt_limit[1], + dt.stride(0), + dt.stride(1), + dt.stride(2), + A.stride(0), + dt_bias.stride(0) if dt_bias is not None else 0, + dt_out.stride(0), + dt_out.stride(2), + dt_out.stride(1), + dt_out.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + dt_softplus, + HAS_DT_BIAS=dt_bias is not None, + BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), + ) + return dA_cumsum, dt_out + + +def _chunk_state_fwd(B, + x, + dt, + dA_cumsum, + seq_idx=None, + states=None, + states_in_fp32=True): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if states is not None: + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + else: + states_dtype = torch.float32 if states_in_fp32 else B.dtype + states = torch.empty((batch, nchunks, nheads, headdim, dstate), + device=x.device, + dtype=states_dtype) + grid = lambda META: ( + triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv( + dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_state_fwd_kernel[grid]( + x, + B, + states, + dt, + dA_cumsum, + seq_idx, + headdim, + dstate, + chunk_size, + batch, + seqlen, + nheads // ngroups, + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + B.stride(0), + B.stride(1), + B.stride(2), + B.stride(-1), + states.stride(0), + states.stride(1), + states.stride(2), + 
states.stride(3), + states.stride(4), + dt.stride(0), + dt.stride(2), + dt.stride(1), + dt.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + HAS_SEQ_IDX=seq_idx is not None, + ) + return states + + +def chunk_state_varlen(B, + x, + dt, + dA_cumsum, + cu_seqlens, + chunk_states, + initial_states=None): + total_seqlen, nheads, headdim = x.shape + _, nchunks, chunk_size = dt.shape + _, ngroups, dstate = B.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + assert nheads % ngroups == 0 + assert B.shape == (total_seqlen, ngroups, dstate) + assert dt.shape == (nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert chunk_states.shape == (nchunks, nheads, headdim, dstate) + + if initial_states is not None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + + states = torch.empty(batch, + nheads, + headdim, + dstate, + dtype=chunk_states.dtype, + device=chunk_states.device) + grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton. + cdiv(dstate, META['BLOCK_SIZE_N']), batch, nheads) + with torch.cuda.device(x.device.index): + _chunk_state_varlen_kernel[grid]( + x, + B, + dt, + dA_cumsum, + chunk_states, + cu_seqlens, + states, + initial_states, + headdim, + dstate, + chunk_size, + total_seqlen, + nheads // ngroups, + x.stride(0), + x.stride(1), + x.stride(2), + B.stride(0), + B.stride(1), + B.stride(2), + dt.stride(1), + dt.stride(0), + dt.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + chunk_states.stride(0), + chunk_states.stride(1), + chunk_states.stride(2), + chunk_states.stride(3), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + *((initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3)) if initial_states is not None else + (0, 0, 0, 0)), + HAS_INITSTATES=initial_states is not None) + return states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py new file mode 100644 index 0000000000000..1f84ff4e7baef --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -0,0 +1,221 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. 
+# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py + +# ruff: noqa: E501 + +import torch +import triton +from einops import rearrange +from packaging import version + +from .ssd_bmm import _bmm_chunk_fwd +from .ssd_chunk_scan import _chunk_scan_fwd +from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd, + chunk_state_varlen) +from .ssd_state_passing import _state_passing_fwd + +TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + + +def _mamba_chunk_scan_combined_fwd(x, + dt, + A, + B, + C, + chunk_size, + D=None, + z=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + cu_seqlens=None, + dt_softplus=False, + dt_limit=(0.0, float("inf"))): + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert x.shape == (batch, seqlen, nheads, headdim) + assert dt.shape == (batch, seqlen, nheads) + assert A.shape == (nheads, ) + assert C.shape == B.shape + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads, ) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if x.stride(-1) != 1 and x.stride( + 1) != 1: # Either M or K dimension should be contiguous + x = x.contiguous() + if z is not None and z.stride(-1) != 1 and z.stride( + 1) != 1: # Either M or K dimension should be contiguous + z = z.contiguous() + if D is not None and D.stride(-1) != 1: + D = D.contiguous() + if initial_states is not None: + if cu_seqlens is None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + else: + assert initial_states.shape == (len(cu_seqlens) - 1, nheads, + headdim, dstate) + + # This function executes 5 sub-functions for computing mamba + # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/ + # which has a minimal implementation to understand the below operations + # - as explained by the blog, mamba is a special case of causal attention + # - the idea is to chunk the attention matrix and compute each + # submatrix separately using different optimizations. + # - see the blog and paper for a visualization of the submatrices + # which we refer to in the comments below + + # 1. Compute chunked cumsum of A * dt + # - here dt may go through a softplus activation + dA_cumsum, dt = _chunk_cumsum_fwd(dt, + A, + chunk_size, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + dt_limit=dt_limit) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + states = _chunk_state_fwd(B, + x, + dt, + dA_cumsum, + seq_idx=seq_idx, + states_in_fp32=True) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + # - for handling chunked prefill, this requires i) initial_states + # ii) seq_idx and iii) has_cu_seqlens to be all specified. + # - When a new seq_idx is detected, we will stop passing the prev_state + # and switch accordingly to the init_state corresponding to the new seq_idx. + # - this will ensure that states will be updated with the rightmost flushed seq_idx + # of the previous chunk. This implies that the first chunk of states is either 0 + # or equal to init_states of the first example. 
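+    #    - as a rough sketch (ignoring seq_idx switches inside the batch), the
+    #      recurrence computed below is
+    #          state[c + 1] = exp(dA_cumsum[..., c, -1]) * state[c] + chunk_state[c]
+    #      where chunk_state[c] is the intra-chunk state from step 2 and
+    #      state[0] is zeros (or the provided initial_states), so the returned
+    #      states[:, c] holds the SSM state entering chunk c.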
+ states, final_states = _state_passing_fwd( + rearrange(states, "... p n -> ... (p n)"), + dA_cumsum[:, :, :, -1], + initial_states=rearrange(initial_states, "... p n -> ... (p n)") + if initial_states is not None else None, + seq_idx=seq_idx, + chunk_size=chunk_size, + out_dtype=C.dtype, + is_cont_batched=cu_seqlens is not None) + states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate) + for t in [states, final_states]) + + # 4. Compute batched matrix multiply for C_j^T B_i terms + CB = _bmm_chunk_fwd(C, + B, + chunk_size, + seq_idx=seq_idx, + output_dtype=torch.float32) + + # 5. Scan and compute the diagonal blocks, taking into + # account past causal states. + # - if initial states are provided, then states information will be + # augmented with initial_states. + # - to do this properly, we need to account for example changes in + # the continuous batch, therefore we introduce pseudo chunks, which is + # a chunk that is split up each time an example changes. + # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had + # a seq_idx change, in which case we take states information from + # init_states. + out, out_x = _chunk_scan_fwd( + CB, + x, + dt, + dA_cumsum, + C, + states, + D=D, + z=z, + seq_idx=seq_idx, + initial_states=initial_states, + ) + if cu_seqlens is None: + return out, out_x, dt, dA_cumsum, states, final_states + else: + assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" + varlen_states = chunk_state_varlen( + B.squeeze(0), + x.squeeze(0), + dt.squeeze(0), + dA_cumsum.squeeze(0), + cu_seqlens, + states.squeeze(0), + initial_states=initial_states, + ) + return out, out_x, dt, dA_cumsum, states, final_states, varlen_states + + +def mamba_chunk_scan_combined(x, + dt, + A, + B, + C, + chunk_size, + D=None, + z=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + cu_seqlens=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + return_final_states=False, + return_varlen_states=False): + """ + Argument: + x: (batch, seqlen, nheads, headdim) + dt: (batch, seqlen, nheads) + A: (nheads) + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + chunk_size: int + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + dt_bias: (nheads,) + initial_states: (batch, nheads, headdim, dstate) + seq_idx: (batch, seqlen) + cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True + dt_softplus: Whether to apply softplus to dt + Return: + out: (batch, seqlen, nheads, headdim) + """ + + if not return_varlen_states: + cu_seqlens = None + else: + assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True" + out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd( + x, + dt, + A, + B, + C, + chunk_size, + D=D, + z=z, + dt_bias=dt_bias, + initial_states=initial_states, + seq_idx=seq_idx, + cu_seqlens=cu_seqlens, + dt_softplus=dt_softplus, + dt_limit=dt_limit) + if not return_varlen_states: + return out if not return_final_states else (out, final_states) + else: + varlen_states = rest[0] + return (out, + varlen_states) if not return_final_states else (out, + final_states, + varlen_states) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py new file mode 100644 index 0000000000000..effa7a76c6875 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -0,0 +1,205 @@ 
+# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py + +# ruff: noqa: E501 + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 64}), + triton.Config({'BLOCK_SIZE': 128}), + triton.Config({'BLOCK_SIZE': 256}), + triton.Config({'BLOCK_SIZE': 512}), + triton.Config({'BLOCK_SIZE': 1024}), + triton.Config({'BLOCK_SIZE': 2048}), + ], + key=['dim'], +) +@triton.jit +def _state_passing_fwd_kernel( + # Pointers to matrices + states_ptr, + out_ptr, + final_states_ptr, + dA_cs_ptr, + initstates_ptr, + seq_idx_ptr, + # Matrix dimensions + dim, + nchunks, + seqlen, + chunk_size, + # Strides + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_dim, + stride_out_batch, + stride_out_chunk, + stride_out_head, + stride_out_dim, + stride_final_states_batch, + stride_final_states_head, + stride_final_states_dim, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_initstates_batch, + stride_initstates_head, + stride_initstates_dim, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + HAS_INITSTATES: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + IS_CONT_BATCHED: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head + dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head + if HAS_INITSTATES: + initstates_ptr += pid_h * stride_initstates_head + if not IS_CONT_BATCHED: + initstates_ptr += pid_b * stride_initstates_batch + + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + states_ptrs = states_ptr + offs_m * stride_states_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim + + # - states will be the past state of the sequence that continues on the current check + if not HAS_INITSTATES: + states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) + else: + initstates_ptr += offs_m * stride_initstates_dim + initstates_ptrs = initstates_ptr + # - for cont batches, for the first chunk mean it will be the first batch's + # init state + states = tl.load(initstates_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) + + tl.store(out_ptrs, states, mask=offs_m < dim) + out_ptrs += stride_out_chunk + seq_idx = 0 + for c in range(nchunks): + new_states = tl.load(states_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) + dA_cs = tl.load(dA_cs_ptr).to(tl.float32) + scale = tl.exp(dA_cs) + if HAS_SEQ_IDX: + # - the seq to pass forward is the one that is flushed to the right + # boundary. + # - that is given by seq_idx_new below. + seq_idx_new = tl.load(seq_idx_ptr + + (min((c + 1) * chunk_size, seqlen) - 1) * + stride_seq_idx_seqlen) + if HAS_INITSTATES: + if IS_CONT_BATCHED and seq_idx != seq_idx_new: + # this means in the current chunk the rightmost flushed seq + # has changed. 
+ # - so we do not propagate the state from previous chunk + # - but rather we load that sequence's init state + initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch + + # - update state with seq_idx_new's init state + states = tl.load(initstates_ptrs, + mask=offs_m < dim, + other=0.0).to(tl.float32) + else: + scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) + + seq_idx = seq_idx_new + states = scale * states + new_states + if c < nchunks - 1: + tl.store(out_ptrs, states, mask=offs_m < dim) + else: + tl.store(final_states_ptrs, states, mask=offs_m < dim) + states_ptrs += stride_states_chunk + dA_cs_ptr += stride_dA_cs_chunk + out_ptrs += stride_out_chunk + + +def _state_passing_fwd( + states, + dA_chunk_cumsum, + initial_states=None, + seq_idx=None, + chunk_size=None, + out_dtype=None, + is_cont_batched=False, +): + batch, nchunks, nheads, dim = states.shape + assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) + if initial_states is not None: + if is_cont_batched: + # - if cu_seqlens is provided, then the initial states + # are used for continuous batching. In which case we + # require seq_idx to be provided + assert seq_idx is not None, "" + assert initial_states.shape == (seq_idx.max().item() + 1, nheads, + dim) + else: + # - this is the regular batching case, where initial + # states are used are for each example of the batch. + assert initial_states.shape == (batch, nheads, dim) + + if seq_idx is not None: + assert chunk_size is not None + seqlen = seq_idx.shape[-1] + assert seq_idx.shape == (batch, seqlen) + out_dtype = states.dtype if out_dtype is None else out_dtype + out = torch.empty((batch, nchunks, nheads, dim), + device=states.device, + dtype=out_dtype) + final_states = torch.empty((batch, nheads, dim), + device=states.device, + dtype=torch.float32) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) + with torch.cuda.device(states.device.index): + _state_passing_fwd_kernel[grid]( + states, + out, + final_states, + dA_chunk_cumsum, + initial_states, + seq_idx, + dim, + nchunks, + seqlen if seq_idx is not None else 0, + chunk_size if seq_idx is not None else 0, + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + final_states.stride(0), + final_states.stride(1), + final_states.stride(2), + dA_chunk_cumsum.stride(0), + dA_chunk_cumsum.stride(2), + dA_chunk_cumsum.stride(1), + *((initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2)) if initial_states is not None else + (0, 0, 0)), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + HAS_INITSTATES=initial_states is not None, + HAS_SEQ_IDX=seq_idx is not None, + IS_CONT_BATCHED=is_cont_batched, + ) + return out, final_states From 17923adcefb6ff1fd6085b5db32a938aed58da97 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 16 Jan 2025 22:07:54 +0000 Subject: [PATCH 02/15] Get Mamba2 working! 
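
This hooks the new Mamba2 mixer up to two new model implementations (Bamba
and Mamba2) and registers Mamba2ForCausalLM. As a rough usage sketch once
this lands (the checkpoint name and sampling settings below are illustrative,
not prescribed by this patch):

    from vllm import LLM, SamplingParams

    llm = LLM(model="mistralai/Mamba-Codestral-7B-v0.1", dtype="half")
    params = SamplingParams(temperature=0.0, max_tokens=64)
    out = llm.generate(["Mamba2 replaces attention with"], params)
    print(out[0].outputs[0].text)
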
Signed-off-by: Tyler Michael Smith --- .../layers/mamba/mamba_mixer2.py | 2 - vllm/model_executor/models/bamba.py | 600 ++++++++++++++++++ vllm/model_executor/models/mamba2.py | 339 ++++++++++ vllm/model_executor/models/registry.py | 1 + 4 files changed, 940 insertions(+), 2 deletions(-) create mode 100644 vllm/model_executor/models/bamba.py create mode 100644 vllm/model_executor/models/mamba2.py diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index c1482a17b0328..4c77c49318604 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -174,7 +174,6 @@ def __init__(self, intermediate_size: int, use_conv_bias: bool, use_bias: bool, - use_rms_norm: bool, n_groups: int = 1, num_heads: int = 128, head_dim: int = 64, @@ -203,7 +202,6 @@ def __init__(self, "Tensor parallel world size must divide num heads." self.ssm_state_size = ssm_state_size - self.use_rms_norm = use_rms_norm self.activation = activation self.chunk_size = chunk_size diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py new file mode 100644 index 0000000000000..0b2848c0ef99d --- /dev/null +++ b/vllm/model_executor/models/bamba.py @@ -0,0 +1,600 @@ +"""Inference-only Bamba model.""" +# Added by the IBM Team, 2024 +from typing import Iterable, List, Optional, Set, Tuple + +import torch +from torch import nn +from transformers import BambaConfig + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_pp_group +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba_mixer2 import ( + MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import LayerBlockType + +from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class BambaMLP(nn.Module): + + def __init__( + self, + config: BambaConfig, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=config.hidden_size, + output_sizes=[config.intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + 
input_size=config.intermediate_size, + output_size=config.hidden_size, + bias=bias, + quant_config=quant_config, + ) + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class BambaMixerDecoderLayer(nn.Module): + + def __init__(self, + config: BambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.config = config + self.mamba = MambaMixer2(hidden_size= config.hidden_size, + ssm_state_size = config.mamba_d_state, + conv_kernel_size = config.mamba_d_conv, + intermediate_size = config.mamba_expand *\ + config.hidden_size, + use_conv_bias = config.mamba_conv_bias, + use_bias = config.mamba_proj_bias, + use_rms_norm=True, + n_groups=config.mamba_n_groups, + num_heads=config.mamba_n_heads, + head_dim=config.mamba_d_head, + rms_norm_eps=config.rms_norm_eps, + activation=config.hidden_act, + chunk_size=config.mamba_chunk_size, + quant_config=quant_config) + + self.feed_forward = BambaMLP(config, quant_config=quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + sequence_idx: Optional[torch.Tensor] = None, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.mamba(hidden_states, attn_metadata, + mamba_cache_params, sequence_idx) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +class BambaAttentionDecoderLayer(nn.Module): + + def __init__( + self, + config: BambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
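+            # e.g. (illustrative numbers) with total_num_kv_heads=2 and
+            # tp_size=8, each of the 2 KV heads is replicated across 4 TP
+            # ranks, so every rank still holds exactly one KV head.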
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if hasattr(config, "partial_rotary_factor"): + rotary_dim = self.head_dim * config.partial_rotary_factor + elif hasattr(config, "attn_rotary_emb"): + rotary_dim = config.attn_rotary_emb # for backward compatibility + else: + rotary_dim = self.head_dim # default + + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=rotary_dim, + max_position=max_position_embeddings, + rope_scaling=rope_scaling, + base=rope_theta, + is_neox_style=True, + dtype=torch.get_default_dtype(), # see impl of get_rope + ) + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + + self.feed_forward = BambaMLP(config, quant_config=quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def self_attention( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attention( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +ALL_DECODER_LAYER_TYPES = { + "attention": BambaAttentionDecoderLayer, + "mamba": BambaMixerDecoderLayer +} + + +class BambaModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, 
+ config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + def get_layer(prefix: str): + layer_idx = int(prefix.rsplit(".", 1)[1]) + layer_class = ALL_DECODER_LAYER_TYPES[ + config.layers_block_type[layer_idx]] + return layer_class( + config, + layer_idx, + cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + self.final_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + # pass a sequence index tensor, that is required for + # proper continuous batching computation including + # chunked prefill + seq_idx = None + if attn_metadata.num_prefills > 0: + seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) + for i, (srt, end) in enumerate( + zip( + attn_metadata.query_start_loc, + attn_metadata.query_start_loc[1:], + )): + seq_idx[srt:end] = i + seq_idx.unsqueeze_(0) + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + residual = None + num_attn = 0 + for i in range(len(self.layers)): + layer = self.layers[i] + kv_cache = None + if isinstance(layer, BambaAttentionDecoderLayer): + kv_cache = kv_caches[num_attn] + num_attn += 1 + + layer_mamba_cache_params = None + if isinstance(layer, BambaMixerDecoderLayer): + layer_mamba_cache_params = mamba_cache_params.at_layer_idx( + i - num_attn) + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + residual=residual, + mamba_cache_params=layer_mamba_cache_params, + sequence_idx=seq_idx, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.final_layernorm(hidden_states, residual) + return hidden_states + + +class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, + IsHybrid): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": ["up_proj", "down_proj"] + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert not 
cache_config.enable_prefix_caching, \ + "Bamba currently does not support prefix caching" + + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.scheduler_config = scheduler_config + self.model = BambaModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + # Used to track and store by the Mamba cache between steps. + self.mamba_cache: Optional[MambaCacheManager] = None + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + # follow jamba + if self.scheduler_config is not None and \ + not self.model_config.enforce_eager: + # for compilation + if self.scheduler_config.max_num_seqs > \ + vllm_config.compilation_config.max_capture_size: + self.max_batch_size = \ + vllm_config.compilation_config.max_capture_size + else: + self.max_batch_size = vllm_config.pad_for_cudagraph( + self.scheduler_config.max_num_seqs) + elif self.scheduler_config is not None: + # for eager just take the scheduler_config if avail + self.max_batch_size = self.scheduler_config.max_num_seqs + else: + self.max_batch_size = 8192 + 2 + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs): + if self.mamba_cache is None: + + num_mamba_layers = self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, LayerBlockType.mamba) + + self.mamba_cache = MambaCacheManager( + self.lm_head.weight.dtype, num_mamba_layers, + self.max_batch_size, *self._get_mamba_cache_shape()) + ( + mamba_cache_tensors, + state_indices_tensor, + ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata, + **kwargs) + mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0], + mamba_cache_tensors[1], + state_indices_tensor) + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, mamba_cache_params, + intermediate_tensors, inputs_embeds) + + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.hidden_size + + conv_state_shape, temporal_state_shape = None, None + + intermediate_size = self.config.mamba_expand * hidden_size + + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = (self.config.mamba_n_groups + 
extra_groups_for_head_shards( + self.config.mamba_n_groups, world_size)) + + # - heads and n_groups are TP-ed + conv_dim = (intermediate_size + + 2 * n_groups * self.config.mamba_d_state) + conv_state_shape = ( + divide(conv_dim, world_size), + self.config.mamba_d_conv - 1, + ) + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, d_head, d_state) = (128, 64, 128) + temporal_state_shape = ( + divide(self.config.mamba_n_heads, world_size), + self.config.mamba_d_head, + self.config.mamba_d_state, + ) + return conv_state_shape, temporal_state_shape + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if "A_log" in name: + name = name.replace("A_log", "A") + + if ".self_attn." in name: + name = name.replace(".self_attn", "") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
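+                # (this for-else branch runs only when no entry of
+                # stacked_params_mapping matched the weight name)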
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py new file mode 100644 index 0000000000000..df51b01696ea3 --- /dev/null +++ b/vllm/model_executor/models/mamba2.py @@ -0,0 +1,339 @@ +"""PyTorch MAMBA2 model.""" +from typing import Iterable, List, Optional, Set, Tuple + +import torch +from torch import nn +from transformers import MambaConfig + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import VllmConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_pp_group +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba_mixer2 import ( + MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import (HasInnerState, + IsAttentionFree) +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import LayerBlockType + +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class Mamba2DecoderLayer(nn.Module): + + def __init__(self, + config: MambaConfig, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + self.config = config + self.mixer = MambaMixer2(hidden_size=config.hidden_size, + ssm_state_size=config.state_size, + conv_kernel_size=config.conv_kernel, + intermediate_size=getattr( + config, "intermediate_size", + config.expand * config.hidden_size), + use_conv_bias=config.use_conv_bias, + use_bias=config.use_bias, + n_groups=config.n_groups, + num_heads=config.num_heads, + head_dim=config.head_dim, + rms_norm_eps=config.layer_norm_epsilon, + activation=config.hidden_act, + chunk_size=config.chunk_size, + quant_config=quant_config) + + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + sequence_idx: Optional[torch.Tensor], + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.norm(hidden_states) + else: + hidden_states, residual = self.norm(hidden_states, residual) + + hidden_states = self.mixer(hidden_states, attn_metadata, + mamba_cache_params, sequence_idx) + return hidden_states, residual + + +class Mamba2Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = 
vllm_config.quant_config + lora_config = vllm_config.lora_config + is_lora_enabled = bool(lora_config) + assert not is_lora_enabled + + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embeddings = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Mamba2DecoderLayer(config, + quant_config=quant_config), + prefix=f"{prefix}.layers") + + self.norm_f = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + # pass a sequence index tensor, that is required for + # proper continuous batching computation including + # chunked prefill + seq_idx = None + if attn_metadata.num_prefills > 0: + seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) + for i, (srt, end) in enumerate( + zip( + attn_metadata.query_start_loc, + attn_metadata.query_start_loc[1:], + )): + seq_idx[srt:end] = i + seq_idx.unsqueeze_(0) + + for i in range(len(self.layers)): + layer = self.layers[i] + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + residual=residual, + mamba_cache_params=mamba_cache_params.at_layer_idx( + i - self.start_layer), + sequence_idx=seq_idx) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm_f(hidden_states, residual) + + return hidden_states + + +class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert not cache_config.enable_prefix_caching, \ + "Mamba does not support prefix caching" + + super().__init__() + self.config = config + self.vllm_config = vllm_config + self.scheduler_config = scheduler_config + self.model_config = vllm_config.model_config + self.backbone = Mamba2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "backbone")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + 
padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings) + + # Used to track and store by the Mamba cache between steps. + self.mamba_cache: Optional[MambaCacheManager] = None + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.backbone.make_empty_intermediate_tensors) + if self.scheduler_config is not None and \ + not self.model_config.enforce_eager: + if self.scheduler_config.max_num_seqs > \ + vllm_config.compilation_config.max_capture_size: + self.max_batch_size = \ + vllm_config.compilation_config.max_capture_size + else: + self.max_batch_size = vllm_config.pad_for_cudagraph( + self.scheduler_config.max_num_seqs) + else: + self.max_batch_size = 8192 + 2 + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.backbone.get_input_embeddings(input_ids) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs): + if self.mamba_cache is None: + num_mamba_layers = self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, LayerBlockType.mamba) + self.mamba_cache = MambaCacheManager( + self.lm_head.weight.dtype, num_mamba_layers, + self.max_batch_size, *self._get_mamba_cache_shape()) + + ( + mamba_cache_tensors, + state_indices_tensor, + ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata, + **kwargs) + + mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0], + mamba_cache_tensors[1], + state_indices_tensor) + + hidden_states = self.backbone(input_ids, positions, attn_metadata, + mamba_cache_params, intermediate_tensors, + inputs_embeds) + + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + + conv_state_shape, temporal_state_shape = None, None + + intermediate_size = getattr( + self.config, "intermediate_size", + self.config.expand * self.config.hidden_size) + + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = ( + self.config.n_groups + + extra_groups_for_head_shards(self.config.n_groups, world_size)) + + # - heads and n_groups are TP-ed + conv_dim = (intermediate_size + 2 * n_groups * self.config.state_size) + conv_state_shape = ( + divide(conv_dim, world_size), + self.config.state_size - 1, + ) + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, d_head, d_state) = (128, 64, 128) + temporal_state_shape = ( + divide(self.config.num_heads, world_size), + self.config.head_dim, + self.config.state_size, + ) + return conv_state_shape, temporal_state_shape + + def compute_logits(self, hidden_states: torch.Tensor, + 
sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "A_log" in name: + name = name.replace("A_log", "A") + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index a71f7f7029c7d..ce0cb63ac59c7 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -68,6 +68,7 @@ "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"), "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"), + "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"), "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), From 4183d45e397f8062da1a311b2a0cfc8acb931105 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 17 Jan 2025 02:03:30 +0000 Subject: [PATCH 03/15] Add integration test -- something is wrong!! 
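
This adds mistralai/Mamba-Codestral-7B-v0.1 to the decoder-only Mamba test
models and switches the test dtype from float to half. The failing case
should be reproducible with something like
"pytest tests/models/decoder_only/language/test_mamba.py -k Codestral"
(the -k filter is a sketch; adjust to your environment).
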
Signed-off-by: Tyler Michael Smith --- tests/models/decoder_only/language/test_mamba.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py index 06739e8f02253..31ed9cb801725 100644 --- a/tests/models/decoder_only/language/test_mamba.py +++ b/tests/models/decoder_only/language/test_mamba.py @@ -10,7 +10,7 @@ from ...utils import check_outputs_equal -MODELS = ["state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev"] +MODELS = ["state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev", "mistralai/Mamba-Codestral-7B-v0.1"] # Use lower-level interfaces to create this greedy generator, as mamba will @@ -38,7 +38,7 @@ def generate_greedy(model_name, example_prompts, max_tokens): @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [96]) def test_models( vllm_runner, From 5377644b9be5ec3a414346ce1b9cfafa1ea7ae08 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 17 Jan 2025 02:09:01 +0000 Subject: [PATCH 04/15] format Signed-off-by: Tyler Michael Smith --- .../decoder_only/language/test_hybrid.py | 350 ++++++++++++++++++ .../decoder_only/language/test_mamba.py | 5 +- 2 files changed, 354 insertions(+), 1 deletion(-) create mode 100644 tests/models/decoder_only/language/test_hybrid.py diff --git a/tests/models/decoder_only/language/test_hybrid.py b/tests/models/decoder_only/language/test_hybrid.py new file mode 100644 index 0000000000000..9ea0c68ab7ec6 --- /dev/null +++ b/tests/models/decoder_only/language/test_hybrid.py @@ -0,0 +1,350 @@ +import pytest + +from tests.utils import multi_gpu_test +from vllm.engine.arg_utils import EngineArgs +from vllm.sampling_params import SamplingParams + +from ...utils import check_outputs_equal + +# This test is for the hybrid models +MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-fms/Bamba-9B"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + + # numeric error produces different generation + if 'Bamba' in model: + example_prompts.pop(3) + + with hf_runner( + model, + dtype=dtype, + model_kwargs={ + "use_mamba_kernels": + False, # mamba kernels are not installed so HF + # don't use them + }) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. 
+ model_runner.model) + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +def test_batching( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # To pass the small model tests, we need full precision. + for_loop_outputs = [] + with vllm_runner(model, dtype=dtype) as vllm_model: + for prompt in example_prompts: + for_loop_outputs.append( + vllm_model.generate_greedy([prompt], max_tokens)[0]) + + batched_outputs = vllm_model.generate_greedy(example_prompts, + max_tokens) + + check_outputs_equal( + outputs_0_lst=for_loop_outputs, + outputs_1_lst=batched_outputs, + name_0="for_loop_vllm", + name_1="batched_vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float16"]) +@pytest.mark.parametrize("max_tokens", [10]) +def test_mamba_prefill_chunking_with_parallel_sampling( + hf_runner, vllm_runner, example_prompts, model: str, dtype: str, + max_tokens: int) -> None: + # Tests prefill chunking in conjunction with n>1, in this case, + # prefill is populated with decoding tokens and we test that it + # doesn't fail This test might fail if cache is not allocated + # correctly for n > 1 decoding steps inside a + # chunked prefill forward pass (where we have both prefills + # and decoding together ) + sampling_params = SamplingParams(n=3, + temperature=1, + seed=0, + max_tokens=max_tokens) + with vllm_runner( + model, + dtype=dtype, + enable_chunked_prefill=True, + max_num_batched_tokens=30, + max_num_seqs=10 # forces prefill chunks with decoding + ) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [7]) +def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, + model: str, dtype: str, + max_tokens: int) -> None: + # numeric error during prefill chucking produces different generation + # compared to w/o prefill chunking for those examples, removed them for now + if 'Jamba' in model: + example_prompts.pop(7) + example_prompts.pop(2) + example_prompts.pop(1) + elif 'Bamba' in model: + example_prompts.pop(6) + example_prompts.pop(3) + example_prompts.pop(2) + dtype = "half" # use a different dtype for Bamba + + with hf_runner( + model, + dtype=dtype, + model_kwargs={ + "use_mamba_kernels": + False, # mamba kernels are not installed so HF + # don't use them + }) as hf_model: + non_chunked = hf_model.generate_greedy(example_prompts, max_tokens) + + with vllm_runner(model, + dtype=dtype, + enable_chunked_prefill=True, + max_num_batched_tokens=5, + max_num_seqs=2) as vllm_model: + chunked = vllm_model.generate_greedy(example_prompts, + max_tokens=max_tokens) + + check_outputs_equal( + outputs_0_lst=chunked, + outputs_1_lst=non_chunked, + name_0="chunked", + name_1="non_chunked", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [15]) +def test_parallel_sampling( + vllm_runner, + example_prompts, + model: str, + dtype: str, + 
max_tokens: int, +) -> None: + + with vllm_runner(model, dtype=dtype) as vllm_model: + for_loop_outputs = [] + for _ in range(10): + for_loop_outputs.append( + # using example_prompts index 1 instead of 0 since with 0 the + # logprobs get really close and the test doesn't pass + vllm_model.generate_greedy([example_prompts[1]], max_tokens) + [0]) + sampling_params = SamplingParams(n=10, + temperature=0.001, + seed=0, + max_tokens=max_tokens) + n_lt_1_outputs = vllm_model.generate([example_prompts[1]], + sampling_params) + token_ids, texts = n_lt_1_outputs[0] + n_lt_1_outputs = [(token_id, text) + for token_id, text in zip(token_ids, texts)] + + check_outputs_equal( + outputs_0_lst=n_lt_1_outputs, + outputs_1_lst=for_loop_outputs, + name_0="vllm_n_lt_1_outputs", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [20]) +def test_mamba_cache_cg_padding( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # This test is for verifying that mamba cache is padded to CG captured + # batch size. If it's not, a torch RuntimeError will be raised because + # tensor dimensions aren't compatible + vllm_config = EngineArgs(model=model).create_engine_config() + while len(example_prompts) == vllm_config.pad_for_cudagraph( + len(example_prompts)): + example_prompts.append(example_prompts[0]) + + try: + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) + except RuntimeError: + pytest.fail( + "Couldn't run batch size which is not equal to a Cuda Graph " + "captured batch size. " + "Could be related to mamba cache not padded correctly") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [20]) +def test_models_preemption_recompute( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # Tests that outputs are identical with and w/o preemtions (recompute) + assert dtype == "float" + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_model.model.llm_engine.scheduler[ + 0].ENABLE_ARTIFICIAL_PREEMPT = True + preempt_vllm_outputs = vllm_model.generate_greedy( + example_prompts, max_tokens) + + vllm_model.model.llm_engine.scheduler[ + 0].ENABLE_ARTIFICIAL_PREEMPT = False + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=preempt_vllm_outputs, + outputs_1_lst=vllm_outputs, + name_0="vllm_preepmtions", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( + vllm_runner, + model: str, + dtype: str, + example_prompts, +) -> None: + # This test is for verifying that the hybrid inner state management doesn't + # collapse in case where the number of incoming requests and + # finished_requests_ids is larger than the maximum mamba block capacity. + # This could generally happen due to the fact that hybrid does support + # statelessness mechanism where it can cleanup new incoming requests in + # a single step. 
+ try: + with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model: + vllm_model.generate_greedy([example_prompts[0]] * 100, 10) + except ValueError: + pytest.fail("Hybrid inner state wasn't cleaned up properly between" + "steps finished requests registered unnecessarily ") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_state_cleanup( + vllm_runner, + model: str, + dtype: str, + example_prompts, +) -> None: + # This test is for verifying that the Hybrid state is cleaned up between + # steps, If its not cleaned, an error would be expected. + try: + with vllm_runner(model, dtype=dtype) as vllm_model: + for _ in range(10): + vllm_model.generate_greedy([example_prompts[0]] * 100, 1) + except ValueError: + pytest.fail("Hybrid inner state wasn't cleaned up between states, " + "could be related to finished_requests_ids") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_multistep( + vllm_runner, + model: str, + dtype: str, + example_prompts, +) -> None: + # This test is verifying that multistep works correctly + #on mamba-like models + with vllm_runner(model, num_scheduler_steps=8, + max_num_seqs=2) as vllm_model: + vllm_model.generate_greedy([example_prompts[0]] * 10, 1) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [64]) +def test_multistep_correctness(vllm_runner, model: str, dtype: str, + max_tokens: int, example_prompts) -> None: + with vllm_runner(model, num_scheduler_steps=8, + max_num_seqs=2) as vllm_model: + vllm_outputs_multistep = vllm_model.generate_greedy( + example_prompts, max_tokens) + + with vllm_runner(model, num_scheduler_steps=1, + max_num_seqs=2) as vllm_model: + vllm_outputs_single_step = vllm_model.generate_greedy( + example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=vllm_outputs_multistep, + outputs_1_lst=vllm_outputs_single_step, + name_0="vllm_outputs_multistep", + name_1="vllm_outputs_single_step", + ) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [64]) +def test_hybrid_distributed_produces_identical_generation( + vllm_runner, model: str, dtype: str, max_tokens: int, + example_prompts) -> None: + + with vllm_runner(model, dtype=dtype, tensor_parallel_size=2) as vllm_model: + vllm_outputs_tp_2 = vllm_model.generate_greedy(example_prompts, + max_tokens) + + with vllm_runner(model, dtype=dtype, tensor_parallel_size=1) as vllm_model: + vllm_outputs_tp_1 = vllm_model.generate_greedy(example_prompts, + max_tokens) + + check_outputs_equal( + outputs_0_lst=vllm_outputs_tp_1, + outputs_1_lst=vllm_outputs_tp_2, + name_0="vllm_tp_1", + name_1="vllm_tp_2", + ) diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py index 31ed9cb801725..27a9b5432cb07 100644 --- a/tests/models/decoder_only/language/test_mamba.py +++ b/tests/models/decoder_only/language/test_mamba.py @@ -10,7 +10,10 @@ from ...utils import check_outputs_equal -MODELS = ["state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev", "mistralai/Mamba-Codestral-7B-v0.1"] +MODELS = [ + "state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev", + "mistralai/Mamba-Codestral-7B-v0.1" +] # Use lower-level interfaces to create this greedy generator, as mamba will From 39f55d1b45e83491498c547276ab4b8a7d756cbe Mon Sep 17 00:00:00 2001 
From: Tyler Michael Smith Date: Fri, 17 Jan 2025 18:01:12 +0000 Subject: [PATCH 05/15] fixes Signed-off-by: Tyler Michael Smith --- tests/models/decoder_only/language/test_mamba.py | 7 ++++--- vllm/model_executor/layers/mamba/mamba_mixer2.py | 5 ++++- vllm/model_executor/models/mamba2.py | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py index 27a9b5432cb07..4698d4f8baf07 100644 --- a/tests/models/decoder_only/language/test_mamba.py +++ b/tests/models/decoder_only/language/test_mamba.py @@ -12,7 +12,8 @@ MODELS = [ "state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev", - "mistralai/Mamba-Codestral-7B-v0.1" + "mistralai/Mamba-Codestral-7B-v0.1", + "/home/tms/mamba2-130m-hf", ] @@ -41,7 +42,7 @@ def generate_greedy(model_name, example_prompts, max_tokens): @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [96]) def test_models( vllm_runner, @@ -52,7 +53,7 @@ def test_models( ) -> None: hf_outputs = generate_greedy(model, example_prompts, max_tokens) - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) # This test is for verifying whether the model's extra_repr # can be printed correctly. diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 4c77c49318604..426773c5ee45a 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -5,6 +5,8 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.attention.backends.placeholder_attn import ( + PlaceholderAttentionMetadata) from vllm.attention.backends.xformers import XFormersMetadata from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -348,7 +350,8 @@ def forward_cuda( # - currently we really only support the FlashAttention backend has_initial_states = None if (isinstance(attn_metadata, - (FlashAttentionMetadata, XFormersMetadata)) + (FlashAttentionMetadata, XFormersMetadata, + PlaceholderAttentionMetadata)) and attn_metadata.context_lens_tensor is not None): has_initial_states = attn_metadata.context_lens_tensor > 0 diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index df51b01696ea3..5827b7ad7753f 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -227,7 +227,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.max_batch_size = vllm_config.pad_for_cudagraph( self.scheduler_config.max_num_seqs) else: - self.max_batch_size = 8192 + 2 + self.max_batch_size = 256 def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.backbone.get_input_embeddings(input_ids) From e2e5aacefca91db2e4f5837dec4a321a2851fd11 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Sun, 19 Jan 2025 18:23:48 +0000 Subject: [PATCH 06/15] Fix for conv state shape and update placeholder_attn Signed-off-by: Tyler Michael Smith Co-authored-by: Yu Chin Fabian Lim --- vllm/attention/backends/placeholder_attn.py | 143 +++++++++----------- vllm/model_executor/models/jamba.py | 11 +- 
vllm/model_executor/models/mamba.py | 10 +- vllm/model_executor/models/mamba2.py | 18 +-- vllm/model_executor/models/mamba_cache.py | 5 +- 5 files changed, 76 insertions(+), 111 deletions(-) diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 534f79b3a60bf..6c05e1e100900 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -1,5 +1,6 @@ from collections import defaultdict from dataclasses import dataclass +from itertools import accumulate from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type import torch @@ -13,6 +14,7 @@ if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) +from vllm.utils import async_tensor_h2d # Placeholder attention backend for models like Mamba and pooling models that # lack attention. @@ -75,11 +77,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata): # seq_lens stored as a tensor. seq_lens_tensor: Optional[torch.Tensor] - # Maximum query length in the batch. - max_query_len: Optional[int] - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. @@ -87,31 +84,33 @@ class PlaceholderAttentionMetadata(AttentionMetadata): # Maximum sequence length among decode batch. 0 if there are prefill # requests only. max_decode_seq_len: int - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] # (batch_size,) A tensor of context lengths (tokens that are computed # so far). context_lens_tensor: Optional[torch.Tensor] - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + # Maximum query length in the batch. + max_query_len: Optional[int] + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + # Placeholder. 
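+    # (Descriptive note, not in the original patch: this backend is for
+    # attention-free models, so block_tables is never consulted; it defaults
+    # to None here and is filled with an empty placeholder tensor elsewhere
+    # only so the shared AttentionMetadata interface stays satisfied.)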
+ block_tables: Optional[torch.Tensor] = None + _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None @@ -123,11 +122,17 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: if self._cached_prefill_metadata is not None: return self._cached_prefill_metadata - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None - assert self.seq_start_loc is not None + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) # Placeholders slot_mapping = torch.empty(0) @@ -140,15 +145,15 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], - max_decode_query_len=0, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_query_len=0, max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, ) @@ -166,6 +171,8 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: # Placeholders slot_mapping = torch.empty(0) block_tables = torch.empty(0) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) self._cached_decode_metadata = PlaceholderAttentionMetadata( num_prefills=0, @@ -174,13 +181,16 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + seq_lens_tensor=seq_lens_tensor, max_decode_query_len=self.max_decode_query_len, max_query_len=None, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, context_lens_tensor=None, block_tables=block_tables, use_cuda_graph=self.use_cuda_graph, @@ -231,8 +241,6 @@ def advance_step(self, assert self.context_lens_tensor is not None assert self.context_lens_tensor.shape == (num_queries, ) - assert self.block_tables is not None - # Update query lengths. 
Note that we update only queries and not seqs, # since tensors may be padded due to captured cuda graph batch size for i in range(num_queries): @@ -293,9 +301,6 @@ def _add_seq_group( self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) else: - assert query_len == 1, ( - "seq_len: {}, context_len: {}, query_len: {}".format( - seq_len, context_len, query_len)) self.num_decode_tokens += query_len self.curr_seq_lens.append(curr_seq_len) @@ -317,15 +322,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 - logits_soft_cap = getattr(self.runner.model_config.hf_config, - "attn_logit_softcapping", None) - if logits_soft_cap is not None: - raise ValueError( - "Please use Flashinfer backend for models with logits_soft_cap" - " (i.e., Gemma-2). Otherwise, the output might be wrong." - " Set Flashinfer backend by " - "export VLLM_ATTENTION_BACKEND=FLASHINFER.") - max_query_len = max(query_lens) decode_query_lens = query_lens[self.num_prefills:] if len(decode_query_lens) > 0: @@ -335,59 +331,48 @@ def build(self, seq_lens: List[int], query_lens: List[int], max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) if use_captured_graph: - num_decode_tokens = batch_size - + num_decode_tokens = batch_size - self.num_prefill_tokens assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - context_lens_tensor = torch.tensor(self.context_lens, - dtype=torch.int, - device=device) - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=device) - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=device) - query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { modality: placeholder_map.index_map() for modality, placeholder_map in self.multimodal_placeholder_maps.items() } - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) # Placeholders - slot_mapping = torch.empty(0) + slot_mapping_tensor = torch.empty(0) block_tables = torch.empty(0) return PlaceholderAttentionMetadata( num_prefills=self.num_prefills, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=placeholder_index_maps, + slot_mapping=slot_mapping_tensor, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, + multi_modal_placeholder_index_maps=placeholder_index_maps, seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, max_decode_query_len=max_decode_query_len, max_prefill_seq_len=max_prefill_seq_len, max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc, - 
seq_start_loc=seq_start_loc, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 890b5530b97d6..b54e892ca1380 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -454,14 +454,9 @@ def forward(self, self.mamba_cache = MambaCacheManager( self.lm_head.weight.dtype, num_mamba_layers, self.max_batch_size, *self._get_mamba_cache_shape()) - ( - mamba_cache_tensors, - state_indices_tensor, - ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata, - **kwargs) - mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0], - mamba_cache_tensors[1], - state_indices_tensor) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, mamba_cache_params, intermediate_tensors, inputs_embeds) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 553bc9c28cb21..5bdf058090433 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -231,15 +231,7 @@ def forward(self, self.lm_head.weight.dtype, num_mamba_layers, self.max_batch_size, *self._get_mamba_cache_shape()) - ( - mamba_cache_tensors, - state_indices_tensor, - ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata, - **kwargs) - - mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0], - mamba_cache_tensors[1], - state_indices_tensor) + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) hidden_states = self.backbone(input_ids, positions, attn_metadata, mamba_cache_params, intermediate_tensors, diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 5827b7ad7753f..5284e11209a35 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -33,7 +33,6 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] - class Mamba2DecoderLayer(nn.Module): def __init__(self, @@ -226,8 +225,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.max_batch_size = vllm_config.pad_for_cudagraph( self.scheduler_config.max_num_seqs) + elif self.scheduler_config is not None: + # For eager just take the scheduler_config if avail + self.max_batch_size = self.scheduler_config.max_num_seqs else: - self.max_batch_size = 256 + self.max_batch_size = 8192 + 2 def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.backbone.get_input_embeddings(input_ids) @@ -247,15 +249,7 @@ def forward(self, self.lm_head.weight.dtype, num_mamba_layers, self.max_batch_size, *self._get_mamba_cache_shape()) - ( - mamba_cache_tensors, - state_indices_tensor, - ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata, - **kwargs) - - mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0], - mamba_cache_tensors[1], - state_indices_tensor) + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) hidden_states = self.backbone(input_ids, positions, attn_metadata, mamba_cache_params, intermediate_tensors, @@ -290,7 +284,7 @@ def _get_mamba_cache_shape( conv_dim = (intermediate_size + 2 * n_groups * self.config.state_size) conv_state_shape = ( divide(conv_dim, world_size), - self.config.state_size - 1, + self.config.conv_kernel - 1, ) # These are not TP-ed as they depend on A, dt_bias, D diff 
--git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py index 79393421f3ae9..444546082212b 100644 --- a/vllm/model_executor/models/mamba_cache.py +++ b/vllm/model_executor/models/mamba_cache.py @@ -40,8 +40,7 @@ def __init__(self, dtype, num_mamba_layers, max_batch_size, self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} self.free_cache_indices = list(range(max_batch_size)) - def current_run_tensors(self, input_ids: torch.Tensor, - attn_metadata: AttentionMetadata, **kwargs): + def current_run_tensors(self, **kwargs) -> MambaCacheParams: """ Return the tensors for the current run's conv and ssm state. """ @@ -64,7 +63,7 @@ def current_run_tensors(self, input_ids: torch.Tensor, (mamba_cache_tensors, state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"] - return (mamba_cache_tensors, state_indices_tensor) + return MambaCacheParams(mamba_cache_tensors[0], mamba_cache_tensors[1], state_indices_tensor) def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): """ From bc1b8afddd1c402a6e2f9066c696831d2c39ad3a Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Sun, 19 Jan 2025 21:28:53 +0000 Subject: [PATCH 07/15] back out placeholder_attn changes Signed-off-by: Tyler Michael Smith --- .../decoder_only/language/test_mamba.py | 5 +- vllm/attention/backends/placeholder_attn.py | 143 ++++++++++-------- vllm/model_executor/models/mamba2.py | 1 + vllm/model_executor/models/mamba_cache.py | 4 +- 4 files changed, 84 insertions(+), 69 deletions(-) diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py index 4698d4f8baf07..75994f36147d4 100644 --- a/tests/models/decoder_only/language/test_mamba.py +++ b/tests/models/decoder_only/language/test_mamba.py @@ -12,8 +12,7 @@ MODELS = [ "state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev", - "mistralai/Mamba-Codestral-7B-v0.1", - "/home/tms/mamba2-130m-hf", + "mistralai/Mamba-Codestral-7B-v0.1" ] @@ -125,7 +124,7 @@ def test_chunked_prefill_with_parallel_sampling(vllm_runner, example_prompts, @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 8]) def test_chunked_prefill(vllm_runner, example_prompts, model: str, dtype: str, max_tokens: int, chunked_prefill_token_size: int) -> None: diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 6c05e1e100900..534f79b3a60bf 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -1,6 +1,5 @@ from collections import defaultdict from dataclasses import dataclass -from itertools import accumulate from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type import torch @@ -14,7 +13,6 @@ if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) -from vllm.utils import async_tensor_h2d # Placeholder attention backend for models like Mamba and pooling models that # lack attention. @@ -77,6 +75,11 @@ class PlaceholderAttentionMetadata(AttentionMetadata): # seq_lens stored as a tensor. seq_lens_tensor: Optional[torch.Tensor] + # Maximum query length in the batch. + max_query_len: Optional[int] + + # Max number of query tokens among request in the batch. 
+ max_decode_query_len: Optional[int] # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. @@ -84,33 +87,31 @@ class PlaceholderAttentionMetadata(AttentionMetadata): # Maximum sequence length among decode batch. 0 if there are prefill # requests only. max_decode_seq_len: int + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] # (batch_size,) A tensor of context lengths (tokens that are computed # so far). context_lens_tensor: Optional[torch.Tensor] + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool - # Maximum query length in the batch. - max_query_len: Optional[int] - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] - - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] = None - - # Placeholder. - block_tables: Optional[torch.Tensor] = None - _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None @@ -122,17 +123,11 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: if self._cached_prefill_metadata is not None: return self._cached_prefill_metadata - # Compute some attn_metadata fields which default to None - query_start_loc = (None if self.query_start_loc is None else - self.query_start_loc[:self.num_prefills + 1]) - seq_lens = (None if self.seq_lens is None else - self.seq_lens[:self.num_prefills]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills]) - seq_start_loc = (None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1]) - context_lens_tensor = (None if self.context_lens_tensor is None else - self.context_lens_tensor[:self.num_prefills]) + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.seq_start_loc is not None # Placeholders slot_mapping = torch.empty(0) @@ -145,15 +140,15 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=self. 
multi_modal_placeholder_index_maps, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_decode_query_len=0, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_query_len=0, max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], block_tables=block_tables, use_cuda_graph=False, ) @@ -171,8 +166,6 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: # Placeholders slot_mapping = torch.empty(0) block_tables = torch.empty(0) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:]) self._cached_decode_metadata = PlaceholderAttentionMetadata( num_prefills=0, @@ -181,16 +174,13 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, seq_lens=None, - seq_lens_tensor=seq_lens_tensor, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_decode_query_len=self.max_decode_query_len, max_query_len=None, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=(self.query_start_loc[self.num_prefills:] - - self.query_start_loc[self.num_prefills]) - if self.query_start_loc is not None else None, - seq_start_loc=self.seq_start_loc[self.num_prefills:] - if self.seq_start_loc is not None else None, + query_start_loc=None, + seq_start_loc=None, context_lens_tensor=None, block_tables=block_tables, use_cuda_graph=self.use_cuda_graph, @@ -241,6 +231,8 @@ def advance_step(self, assert self.context_lens_tensor is not None assert self.context_lens_tensor.shape == (num_queries, ) + assert self.block_tables is not None + # Update query lengths. Note that we update only queries and not seqs, # since tensors may be padded due to captured cuda graph batch size for i in range(num_queries): @@ -301,6 +293,9 @@ def _add_seq_group( self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) self.num_decode_tokens += query_len self.curr_seq_lens.append(curr_seq_len) @@ -322,6 +317,15 @@ def build(self, seq_lens: List[int], query_lens: List[int], device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 + logits_soft_cap = getattr(self.runner.model_config.hf_config, + "attn_logit_softcapping", None) + if logits_soft_cap is not None: + raise ValueError( + "Please use Flashinfer backend for models with logits_soft_cap" + " (i.e., Gemma-2). Otherwise, the output might be wrong." 
+ " Set Flashinfer backend by " + "export VLLM_ATTENTION_BACKEND=FLASHINFER.") + max_query_len = max(query_lens) decode_query_lens = query_lens[self.num_prefills:] if len(decode_query_lens) > 0: @@ -331,48 +335,59 @@ def build(self, seq_lens: List[int], query_lens: List[int], max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens - query_start_loc = list(accumulate(query_lens, initial=0)) - seq_start_loc = list(accumulate(seq_lens, initial=0)) if use_captured_graph: - num_decode_tokens = batch_size - self.num_prefill_tokens - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + num_decode_tokens = batch_size - assert device is not None - context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, - device, self.runner.pin_memory) - seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, - self.runner.pin_memory) - query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, - device, - self.runner.pin_memory) - seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, - device, self.runner.pin_memory) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + context_lens_tensor = torch.tensor(self.context_lens, + dtype=torch.int, + device=device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) placeholder_index_maps = { modality: placeholder_map.index_map() for modality, placeholder_map in self.multimodal_placeholder_maps.items() } + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) # Placeholders - slot_mapping_tensor = torch.empty(0) + slot_mapping = torch.empty(0) block_tables = torch.empty(0) return PlaceholderAttentionMetadata( num_prefills=self.num_prefills, - slot_mapping=slot_mapping_tensor, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=placeholder_index_maps, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, - multi_modal_placeholder_index_maps=placeholder_index_maps, seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, max_decode_query_len=max_decode_query_len, max_prefill_seq_len=max_prefill_seq_len, max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc_tensor, - seq_start_loc=seq_start_loc_tensor, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 5284e11209a35..60d4922cad913 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -33,6 +33,7 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] + class Mamba2DecoderLayer(nn.Module): def __init__(self, diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py index 444546082212b..2cabcc660aedc 100644 --- a/vllm/model_executor/models/mamba_cache.py +++ b/vllm/model_executor/models/mamba_cache.py @@ -3,7 +3,6 @@ import torch -from 
vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.utils import PAD_SLOT_ID @@ -63,7 +62,8 @@ def current_run_tensors(self, **kwargs) -> MambaCacheParams: (mamba_cache_tensors, state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"] - return MambaCacheParams(mamba_cache_tensors[0], mamba_cache_tensors[1], state_indices_tensor) + return MambaCacheParams(mamba_cache_tensors[0], mamba_cache_tensors[1], + state_indices_tensor) def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): """ From cd892836de83aa465fefa17c8ac3bb1053e79df9 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 20 Jan 2025 14:28:16 +0000 Subject: [PATCH 08/15] WIP debugging, restore local mamba and placeholder_attn changes Signed-off-by: Tyler Michael Smith --- .../decoder_only/language/test_mamba.py | 3 +- vllm/attention/backends/placeholder_attn.py | 143 ++++++++---------- 2 files changed, 66 insertions(+), 80 deletions(-) diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py index 75994f36147d4..ecb3455d99284 100644 --- a/tests/models/decoder_only/language/test_mamba.py +++ b/tests/models/decoder_only/language/test_mamba.py @@ -12,7 +12,8 @@ MODELS = [ "state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev", - "mistralai/Mamba-Codestral-7B-v0.1" + "mistralai/Mamba-Codestral-7B-v0.1", + "/home/tms/mamba2-130m-hf", ] diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 534f79b3a60bf..6c05e1e100900 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -1,5 +1,6 @@ from collections import defaultdict from dataclasses import dataclass +from itertools import accumulate from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type import torch @@ -13,6 +14,7 @@ if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) +from vllm.utils import async_tensor_h2d # Placeholder attention backend for models like Mamba and pooling models that # lack attention. @@ -75,11 +77,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata): # seq_lens stored as a tensor. seq_lens_tensor: Optional[torch.Tensor] - # Maximum query length in the batch. - max_query_len: Optional[int] - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. @@ -87,31 +84,33 @@ class PlaceholderAttentionMetadata(AttentionMetadata): # Maximum sequence length among decode batch. 0 if there are prefill # requests only. max_decode_seq_len: int - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] # (batch_size,) A tensor of context lengths (tokens that are computed # so far). context_lens_tensor: Optional[torch.Tensor] - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. 
Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + # Maximum query length in the batch. + max_query_len: Optional[int] + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + # Placeholder. + block_tables: Optional[torch.Tensor] = None + _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None @@ -123,11 +122,17 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: if self._cached_prefill_metadata is not None: return self._cached_prefill_metadata - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None - assert self.seq_start_loc is not None + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) # Placeholders slot_mapping = torch.empty(0) @@ -140,15 +145,15 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=self. 
multi_modal_placeholder_index_maps, - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], - max_decode_query_len=0, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_query_len=0, max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, ) @@ -166,6 +171,8 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: # Placeholders slot_mapping = torch.empty(0) block_tables = torch.empty(0) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) self._cached_decode_metadata = PlaceholderAttentionMetadata( num_prefills=0, @@ -174,13 +181,16 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + seq_lens_tensor=seq_lens_tensor, max_decode_query_len=self.max_decode_query_len, max_query_len=None, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, context_lens_tensor=None, block_tables=block_tables, use_cuda_graph=self.use_cuda_graph, @@ -231,8 +241,6 @@ def advance_step(self, assert self.context_lens_tensor is not None assert self.context_lens_tensor.shape == (num_queries, ) - assert self.block_tables is not None - # Update query lengths. Note that we update only queries and not seqs, # since tensors may be padded due to captured cuda graph batch size for i in range(num_queries): @@ -293,9 +301,6 @@ def _add_seq_group( self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) else: - assert query_len == 1, ( - "seq_len: {}, context_len: {}, query_len: {}".format( - seq_len, context_len, query_len)) self.num_decode_tokens += query_len self.curr_seq_lens.append(curr_seq_len) @@ -317,15 +322,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 - logits_soft_cap = getattr(self.runner.model_config.hf_config, - "attn_logit_softcapping", None) - if logits_soft_cap is not None: - raise ValueError( - "Please use Flashinfer backend for models with logits_soft_cap" - " (i.e., Gemma-2). Otherwise, the output might be wrong." 
- " Set Flashinfer backend by " - "export VLLM_ATTENTION_BACKEND=FLASHINFER.") - max_query_len = max(query_lens) decode_query_lens = query_lens[self.num_prefills:] if len(decode_query_lens) > 0: @@ -335,59 +331,48 @@ def build(self, seq_lens: List[int], query_lens: List[int], max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) if use_captured_graph: - num_decode_tokens = batch_size - + num_decode_tokens = batch_size - self.num_prefill_tokens assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - context_lens_tensor = torch.tensor(self.context_lens, - dtype=torch.int, - device=device) - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=device) - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=device) - query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { modality: placeholder_map.index_map() for modality, placeholder_map in self.multimodal_placeholder_maps.items() } - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) # Placeholders - slot_mapping = torch.empty(0) + slot_mapping_tensor = torch.empty(0) block_tables = torch.empty(0) return PlaceholderAttentionMetadata( num_prefills=self.num_prefills, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=placeholder_index_maps, + slot_mapping=slot_mapping_tensor, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, + multi_modal_placeholder_index_maps=placeholder_index_maps, seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, max_decode_query_len=max_decode_query_len, max_prefill_seq_len=max_prefill_seq_len, max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, From 9a838a3d6764e7723329d3d507d79a7f20ed41e2 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 20 Jan 2025 23:27:15 +0000 Subject: [PATCH 09/15] Integration tests are now green Signed-off-by: Tyler Michael Smith --- .../decoder_only/language/test_mamba.py | 26 ++++++++++++------- vllm/attention/backends/placeholder_attn.py | 1 - .../layers/mamba/mamba_mixer2.py | 5 +++- vllm/model_executor/models/mamba2.py | 2 +- 4 files changed, 22 insertions(+), 12 deletions(-) diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py index ecb3455d99284..24f7e1b97cfd5 100644 --- 
a/tests/models/decoder_only/language/test_mamba.py +++ b/tests/models/decoder_only/language/test_mamba.py @@ -3,6 +3,7 @@ Run `pytest tests/models/test_mamba.py`. """ import pytest +import torch from transformers import AutoModelForCausalLM, AutoTokenizer from vllm.engine.arg_utils import EngineArgs @@ -11,9 +12,9 @@ from ...utils import check_outputs_equal MODELS = [ - "state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev", + "state-spaces/mamba-130m-hf", + "tiiuae/falcon-mamba-tiny-dev", "mistralai/Mamba-Codestral-7B-v0.1", - "/home/tms/mamba2-130m-hf", ] @@ -24,6 +25,10 @@ def generate_greedy(model_name, example_prompts, max_tokens): tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained(model_name) + # Set the device (GPU if available, else CPU) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + # Generate texts from the prompts outputs = [] for prompt in example_prompts: @@ -32,7 +37,9 @@ def generate_greedy(model_name, example_prompts, max_tokens): input_ids = inputs["input_ids"].to(model.device) # Generate text using the model's generate method directly - generated_ids = model.generate(input_ids, max_new_tokens=max_tokens) + generated_ids = model.generate(input_ids, + max_new_tokens=max_tokens, + do_sample=False) generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) @@ -53,7 +60,8 @@ def test_models( ) -> None: hf_outputs = generate_greedy(model, example_prompts, max_tokens) - with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: + # Set max_num_seqs to keep Codestral from going OOM at fp32 + with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) # This test is for verifying whether the model's extra_repr # can be printed correctly. @@ -81,7 +89,7 @@ def test_batching( ) -> None: # To pass the small model tests, we need full precision. 
for_loop_outputs = [] - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model: for prompt in example_prompts: for_loop_outputs.append( vllm_model.generate_greedy([prompt], max_tokens)[0]) @@ -125,7 +133,7 @@ def test_chunked_prefill_with_parallel_sampling(vllm_runner, example_prompts, @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 8]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) def test_chunked_prefill(vllm_runner, example_prompts, model: str, dtype: str, max_tokens: int, chunked_prefill_token_size: int) -> None: @@ -165,7 +173,7 @@ def test_parallel_sampling( max_tokens: int, ) -> None: - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model: for_loop_outputs = [] for _ in range(10): for_loop_outputs.append( @@ -232,7 +240,7 @@ def test_models_preemption_recompute( # Tests that outputs are identical with and w/o preemtions (recompute) assert dtype == "float" - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model: vllm_model.model.llm_engine.scheduler[ 0].ENABLE_ARTIFICIAL_PREEMPT = True preempt_vllm_outputs = vllm_model.generate_greedy( @@ -283,7 +291,7 @@ def test_state_cleanup( # This test is for verifying that the Mamba state is cleaned up between # steps, If its not cleaned, an error would be expected. try: - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model: for _ in range(10): vllm_model.generate_greedy([example_prompts[0]] * 100, 1) except ValueError: diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 6c05e1e100900..65a58d9b63dac 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -77,7 +77,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata): # seq_lens stored as a tensor. seq_lens_tensor: Optional[torch.Tensor] - # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. max_prefill_seq_len: int diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 426773c5ee45a..53adb5623bd80 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -391,6 +391,9 @@ def forward_cuda( cache_indices=mamba_cache_params.state_indices_tensor, query_start_loc=attn_metadata.query_start_loc).transpose( 0, 1)[:seq_len] + + # TODO: Why is this needed? 
+ hidden_states_B_C = hidden_states_B_C.contiguous() else: hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, @@ -463,7 +466,7 @@ def forward_cuda( -1, self.num_heads // self.tp_size, self.head_dim) # - the hidden is reshaped into number of current batches - # - in this case there is no more prefil, so the batches gen + # - in this case there is no more prefill, so the batches gen # 1 token at a time # - thus hidden will be (bs, num_heads, head_dim) # - mamba_cache_params.ssm_state's slots will be selected diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 60d4922cad913..545d151fe05c0 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -230,7 +230,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # For eager just take the scheduler_config if avail self.max_batch_size = self.scheduler_config.max_num_seqs else: - self.max_batch_size = 8192 + 2 + self.max_batch_size = 128 + 2 def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.backbone.get_input_embeddings(input_ids) From be8318e87016c81193c5d5c40ec879198ce69d24 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 20 Jan 2025 23:38:58 +0000 Subject: [PATCH 10/15] remove bamba-specific files Signed-off-by: Tyler Michael Smith --- .../decoder_only/language/test_hybrid.py | 350 ---------- vllm/model_executor/models/bamba.py | 600 ------------------ 2 files changed, 950 deletions(-) delete mode 100644 tests/models/decoder_only/language/test_hybrid.py delete mode 100644 vllm/model_executor/models/bamba.py diff --git a/tests/models/decoder_only/language/test_hybrid.py b/tests/models/decoder_only/language/test_hybrid.py deleted file mode 100644 index 9ea0c68ab7ec6..0000000000000 --- a/tests/models/decoder_only/language/test_hybrid.py +++ /dev/null @@ -1,350 +0,0 @@ -import pytest - -from tests.utils import multi_gpu_test -from vllm.engine.arg_utils import EngineArgs -from vllm.sampling_params import SamplingParams - -from ...utils import check_outputs_equal - -# This test is for the hybrid models -MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-fms/Bamba-9B"] - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [96]) -def test_models( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - - # numeric error produces different generation - if 'Bamba' in model: - example_prompts.pop(3) - - with hf_runner( - model, - dtype=dtype, - model_kwargs={ - "use_mamba_kernels": - False, # mamba kernels are not installed so HF - # don't use them - }) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - # This test is for verifying whether the model's extra_repr - # can be printed correctly. - print(vllm_model.model.llm_engine.model_executor.driver_worker. 
- model_runner.model) - - for i in range(len(example_prompts)): - hf_output_ids, hf_output_str = hf_outputs[i] - vllm_output_ids, vllm_output_str = vllm_outputs[i] - assert hf_output_str == vllm_output_str, ( - f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") - assert hf_output_ids == vllm_output_ids, ( - f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [96]) -def test_batching( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - # To pass the small model tests, we need full precision. - for_loop_outputs = [] - with vllm_runner(model, dtype=dtype) as vllm_model: - for prompt in example_prompts: - for_loop_outputs.append( - vllm_model.generate_greedy([prompt], max_tokens)[0]) - - batched_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens) - - check_outputs_equal( - outputs_0_lst=for_loop_outputs, - outputs_1_lst=batched_outputs, - name_0="for_loop_vllm", - name_1="batched_vllm", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float16"]) -@pytest.mark.parametrize("max_tokens", [10]) -def test_mamba_prefill_chunking_with_parallel_sampling( - hf_runner, vllm_runner, example_prompts, model: str, dtype: str, - max_tokens: int) -> None: - # Tests prefill chunking in conjunction with n>1, in this case, - # prefill is populated with decoding tokens and we test that it - # doesn't fail This test might fail if cache is not allocated - # correctly for n > 1 decoding steps inside a - # chunked prefill forward pass (where we have both prefills - # and decoding together ) - sampling_params = SamplingParams(n=3, - temperature=1, - seed=0, - max_tokens=max_tokens) - with vllm_runner( - model, - dtype=dtype, - enable_chunked_prefill=True, - max_num_batched_tokens=30, - max_num_seqs=10 # forces prefill chunks with decoding - ) as vllm_model: - vllm_model.generate(example_prompts, sampling_params) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [7]) -def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, - model: str, dtype: str, - max_tokens: int) -> None: - # numeric error during prefill chucking produces different generation - # compared to w/o prefill chunking for those examples, removed them for now - if 'Jamba' in model: - example_prompts.pop(7) - example_prompts.pop(2) - example_prompts.pop(1) - elif 'Bamba' in model: - example_prompts.pop(6) - example_prompts.pop(3) - example_prompts.pop(2) - dtype = "half" # use a different dtype for Bamba - - with hf_runner( - model, - dtype=dtype, - model_kwargs={ - "use_mamba_kernels": - False, # mamba kernels are not installed so HF - # don't use them - }) as hf_model: - non_chunked = hf_model.generate_greedy(example_prompts, max_tokens) - - with vllm_runner(model, - dtype=dtype, - enable_chunked_prefill=True, - max_num_batched_tokens=5, - max_num_seqs=2) as vllm_model: - chunked = vllm_model.generate_greedy(example_prompts, - max_tokens=max_tokens) - - check_outputs_equal( - outputs_0_lst=chunked, - outputs_1_lst=non_chunked, - name_0="chunked", - name_1="non_chunked", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [15]) -def test_parallel_sampling( - vllm_runner, - example_prompts, - model: str, - dtype: str, - 
max_tokens: int, -) -> None: - - with vllm_runner(model, dtype=dtype) as vllm_model: - for_loop_outputs = [] - for _ in range(10): - for_loop_outputs.append( - # using example_prompts index 1 instead of 0 since with 0 the - # logprobs get really close and the test doesn't pass - vllm_model.generate_greedy([example_prompts[1]], max_tokens) - [0]) - sampling_params = SamplingParams(n=10, - temperature=0.001, - seed=0, - max_tokens=max_tokens) - n_lt_1_outputs = vllm_model.generate([example_prompts[1]], - sampling_params) - token_ids, texts = n_lt_1_outputs[0] - n_lt_1_outputs = [(token_id, text) - for token_id, text in zip(token_ids, texts)] - - check_outputs_equal( - outputs_0_lst=n_lt_1_outputs, - outputs_1_lst=for_loop_outputs, - name_0="vllm_n_lt_1_outputs", - name_1="vllm", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [20]) -def test_mamba_cache_cg_padding( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - # This test is for verifying that mamba cache is padded to CG captured - # batch size. If it's not, a torch RuntimeError will be raised because - # tensor dimensions aren't compatible - vllm_config = EngineArgs(model=model).create_engine_config() - while len(example_prompts) == vllm_config.pad_for_cudagraph( - len(example_prompts)): - example_prompts.append(example_prompts[0]) - - try: - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_model.generate_greedy(example_prompts, max_tokens) - except RuntimeError: - pytest.fail( - "Couldn't run batch size which is not equal to a Cuda Graph " - "captured batch size. " - "Could be related to mamba cache not padded correctly") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [20]) -def test_models_preemption_recompute( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - # Tests that outputs are identical with and w/o preemtions (recompute) - assert dtype == "float" - - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_model.model.llm_engine.scheduler[ - 0].ENABLE_ARTIFICIAL_PREEMPT = True - preempt_vllm_outputs = vllm_model.generate_greedy( - example_prompts, max_tokens) - - vllm_model.model.llm_engine.scheduler[ - 0].ENABLE_ARTIFICIAL_PREEMPT = False - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=preempt_vllm_outputs, - outputs_1_lst=vllm_outputs, - name_0="vllm_preepmtions", - name_1="vllm", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( - vllm_runner, - model: str, - dtype: str, - example_prompts, -) -> None: - # This test is for verifying that the hybrid inner state management doesn't - # collapse in case where the number of incoming requests and - # finished_requests_ids is larger than the maximum mamba block capacity. - # This could generally happen due to the fact that hybrid does support - # statelessness mechanism where it can cleanup new incoming requests in - # a single step. 
- try: - with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model: - vllm_model.generate_greedy([example_prompts[0]] * 100, 10) - except ValueError: - pytest.fail("Hybrid inner state wasn't cleaned up properly between" - "steps finished requests registered unnecessarily ") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -def test_state_cleanup( - vllm_runner, - model: str, - dtype: str, - example_prompts, -) -> None: - # This test is for verifying that the Hybrid state is cleaned up between - # steps, If its not cleaned, an error would be expected. - try: - with vllm_runner(model, dtype=dtype) as vllm_model: - for _ in range(10): - vllm_model.generate_greedy([example_prompts[0]] * 100, 1) - except ValueError: - pytest.fail("Hybrid inner state wasn't cleaned up between states, " - "could be related to finished_requests_ids") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -def test_multistep( - vllm_runner, - model: str, - dtype: str, - example_prompts, -) -> None: - # This test is verifying that multistep works correctly - #on mamba-like models - with vllm_runner(model, num_scheduler_steps=8, - max_num_seqs=2) as vllm_model: - vllm_model.generate_greedy([example_prompts[0]] * 10, 1) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [64]) -def test_multistep_correctness(vllm_runner, model: str, dtype: str, - max_tokens: int, example_prompts) -> None: - with vllm_runner(model, num_scheduler_steps=8, - max_num_seqs=2) as vllm_model: - vllm_outputs_multistep = vllm_model.generate_greedy( - example_prompts, max_tokens) - - with vllm_runner(model, num_scheduler_steps=1, - max_num_seqs=2) as vllm_model: - vllm_outputs_single_step = vllm_model.generate_greedy( - example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=vllm_outputs_multistep, - outputs_1_lst=vllm_outputs_single_step, - name_0="vllm_outputs_multistep", - name_1="vllm_outputs_single_step", - ) - - -@multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [64]) -def test_hybrid_distributed_produces_identical_generation( - vllm_runner, model: str, dtype: str, max_tokens: int, - example_prompts) -> None: - - with vllm_runner(model, dtype=dtype, tensor_parallel_size=2) as vllm_model: - vllm_outputs_tp_2 = vllm_model.generate_greedy(example_prompts, - max_tokens) - - with vllm_runner(model, dtype=dtype, tensor_parallel_size=1) as vllm_model: - vllm_outputs_tp_1 = vllm_model.generate_greedy(example_prompts, - max_tokens) - - check_outputs_equal( - outputs_0_lst=vllm_outputs_tp_1, - outputs_1_lst=vllm_outputs_tp_2, - name_0="vllm_tp_1", - name_1="vllm_tp_2", - ) diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py deleted file mode 100644 index 0b2848c0ef99d..0000000000000 --- a/vllm/model_executor/models/bamba.py +++ /dev/null @@ -1,600 +0,0 @@ -"""Inference-only Bamba model.""" -# Added by the IBM Team, 2024 -from typing import Iterable, List, Optional, Set, Tuple - -import torch -from torch import nn -from transformers import BambaConfig - -from vllm.attention.backends.abstract import AttentionMetadata -from vllm.attention.layer import Attention -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.distributed.parallel_state import 
get_pp_group -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba_mixer2 import ( - MambaMixer2, extra_groups_for_head_shards) -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType - -from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) - -KVCache = Tuple[torch.Tensor, torch.Tensor] - - -class BambaMLP(nn.Module): - - def __init__( - self, - config: BambaConfig, - quant_config: Optional[QuantizationConfig] = None, - bias: bool = False, - ) -> None: - super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - input_size=config.hidden_size, - output_sizes=[config.intermediate_size] * 2, - bias=bias, - quant_config=quant_config, - ) - self.down_proj = RowParallelLinear( - input_size=config.intermediate_size, - output_size=config.hidden_size, - bias=bias, - quant_config=quant_config, - ) - if config.hidden_act != "silu": - raise ValueError(f"Unsupported activation: {config.hidden_act}. 
" - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - def forward(self, x): - x, _ = self.gate_up_proj(x) - x = self.act_fn(x) - x, _ = self.down_proj(x) - return x - - -class BambaMixerDecoderLayer(nn.Module): - - def __init__(self, - config: BambaConfig, - layer_idx: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: - super().__init__() - self.config = config - self.mamba = MambaMixer2(hidden_size= config.hidden_size, - ssm_state_size = config.mamba_d_state, - conv_kernel_size = config.mamba_d_conv, - intermediate_size = config.mamba_expand *\ - config.hidden_size, - use_conv_bias = config.mamba_conv_bias, - use_bias = config.mamba_proj_bias, - use_rms_norm=True, - n_groups=config.mamba_n_groups, - num_heads=config.mamba_n_heads, - head_dim=config.mamba_d_head, - rms_norm_eps=config.rms_norm_eps, - activation=config.hidden_act, - chunk_size=config.mamba_chunk_size, - quant_config=quant_config) - - self.feed_forward = BambaMLP(config, quant_config=quant_config) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_ff_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, - residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - sequence_idx: Optional[torch.Tensor] = None, - **kwargs, - ): - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - - hidden_states = self.mamba(hidden_states, attn_metadata, - mamba_cache_params, sequence_idx) - # Fully Connected - hidden_states, residual = self.pre_ff_layernorm( - hidden_states, residual) - hidden_states = self.feed_forward(hidden_states) - return hidden_states, residual - - -class BambaAttentionDecoderLayer(nn.Module): - - def __init__( - self, - config: BambaConfig, - layer_idx: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - self.hidden_size = config.hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = config.num_attention_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = config.num_key_value_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. 
- assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = config.hidden_size // self.total_num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - if hasattr(config, "partial_rotary_factor"): - rotary_dim = self.head_dim * config.partial_rotary_factor - elif hasattr(config, "attn_rotary_emb"): - rotary_dim = config.attn_rotary_emb # for backward compatibility - else: - rotary_dim = self.head_dim # default - - self.rotary_emb = get_rope( - head_size=self.head_dim, - rotary_dim=rotary_dim, - max_position=max_position_embeddings, - rope_scaling=rope_scaling, - base=rope_theta, - is_neox_style=True, - dtype=torch.get_default_dtype(), # see impl of get_rope - ) - - self.qkv_proj = QKVParallelLinear( - config.hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=False, - quant_config=quant_config, - ) - self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, - config.hidden_size, - bias=False, - quant_config=quant_config) - - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - prefix=f"{prefix}.attn", - ) - - self.feed_forward = BambaMLP(config, quant_config=quant_config) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_ff_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def self_attention( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - **kwargs, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) - output, _ = self.o_proj(attn_output) - return output - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - residual: Optional[torch.Tensor], - **kwargs, - ): - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - - hidden_states = self.self_attention( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) - # Fully Connected - hidden_states, residual = self.pre_ff_layernorm( - hidden_states, residual) - hidden_states = self.feed_forward(hidden_states) - return hidden_states, residual - - -ALL_DECODER_LAYER_TYPES = { - "attention": BambaAttentionDecoderLayer, - "mamba": BambaMixerDecoderLayer -} - - -class BambaModel(nn.Module): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - - self.config = config - self.padding_idx = config.pad_token_id - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size - - self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, 
- config.hidden_size, - org_num_embeddings=config.vocab_size, - ) - - def get_layer(prefix: str): - layer_idx = int(prefix.rsplit(".", 1)[1]) - layer_class = ALL_DECODER_LAYER_TYPES[ - config.layers_block_type[layer_idx]] - return layer_class( - config, - layer_idx, - cache_config, - quant_config=quant_config, - prefix=prefix, - ) - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) - - self.final_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - mamba_cache_params: MambaCacheParams, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - - # pass a sequence index tensor, that is required for - # proper continuous batching computation including - # chunked prefill - seq_idx = None - if attn_metadata.num_prefills > 0: - seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) - for i, (srt, end) in enumerate( - zip( - attn_metadata.query_start_loc, - attn_metadata.query_start_loc[1:], - )): - seq_idx[srt:end] = i - seq_idx.unsqueeze_(0) - - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - - residual = None - num_attn = 0 - for i in range(len(self.layers)): - layer = self.layers[i] - kv_cache = None - if isinstance(layer, BambaAttentionDecoderLayer): - kv_cache = kv_caches[num_attn] - num_attn += 1 - - layer_mamba_cache_params = None - if isinstance(layer, BambaMixerDecoderLayer): - layer_mamba_cache_params = mamba_cache_params.at_layer_idx( - i - num_attn) - - hidden_states, residual = layer( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - residual=residual, - mamba_cache_params=layer_mamba_cache_params, - sequence_idx=seq_idx, - ) - - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - hidden_states, _ = self.final_layernorm(hidden_states, residual) - return hidden_states - - -class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": ["up_proj", "down_proj"] - } - - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "embed_tokens", - "lm_head", - ] - embedding_modules = { - "embed_tokens": "input_embeddings", - "lm_head": "output_embeddings", - } - embedding_padding_modules = ["lm_head"] - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - config = vllm_config.model_config.hf_config - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config - lora_config = vllm_config.lora_config - scheduler_config = vllm_config.scheduler_config - assert not 
cache_config.enable_prefix_caching, \ - "Bamba currently does not support prefix caching" - - self.quant_config = vllm_config.quant_config - - super().__init__() - self.config = config - self.scheduler_config = scheduler_config - self.model = BambaModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, - ) - # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Optional[MambaCacheManager] = None - - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) - self.sampler = get_sampler() - - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - # follow jamba - if self.scheduler_config is not None and \ - not self.model_config.enforce_eager: - # for compilation - if self.scheduler_config.max_num_seqs > \ - vllm_config.compilation_config.max_capture_size: - self.max_batch_size = \ - vllm_config.compilation_config.max_capture_size - else: - self.max_batch_size = vllm_config.pad_for_cudagraph( - self.scheduler_config.max_num_seqs) - elif self.scheduler_config is not None: - # for eager just take the scheduler_config if avail - self.max_batch_size = self.scheduler_config.max_num_seqs - else: - self.max_batch_size = 8192 + 2 - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) - - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs): - if self.mamba_cache is None: - - num_mamba_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) - - self.mamba_cache = MambaCacheManager( - self.lm_head.weight.dtype, num_mamba_layers, - self.max_batch_size, *self._get_mamba_cache_shape()) - ( - mamba_cache_tensors, - state_indices_tensor, - ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata, - **kwargs) - mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0], - mamba_cache_tensors[1], - state_indices_tensor) - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, mamba_cache_params, - intermediate_tensors, inputs_embeds) - - return hidden_states - - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - - def _get_mamba_cache_shape( - self) -> Tuple[Tuple[int, int], Tuple[int, int]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = self.config.hidden_size - - conv_state_shape, temporal_state_shape = None, None - - intermediate_size = self.config.mamba_expand * hidden_size - - # if n_groups is not divisible by world_size, need to extend the shards - # to ensure all groups needed by a head is sharded along with it - n_groups = (self.config.mamba_n_groups + 
extra_groups_for_head_shards( - self.config.mamba_n_groups, world_size)) - - # - heads and n_groups are TP-ed - conv_dim = (intermediate_size + - 2 * n_groups * self.config.mamba_d_state) - conv_state_shape = ( - divide(conv_dim, world_size), - self.config.mamba_d_conv - 1, - ) - - # These are not TP-ed as they depend on A, dt_bias, D - # - they are typically small - # e.g., (h_heads, d_head, d_state) = (128, 64, 128) - temporal_state_shape = ( - divide(self.config.mamba_n_heads, world_size), - self.config.mamba_d_head, - self.config.mamba_d_state, - ) - return conv_state_shape, temporal_state_shape - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - - if "A_log" in name: - name = name.replace("A_log", "A") - - if ".self_attn." in name: - name = name.replace(".self_attn", "") - - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params From a65e2cb8677fbbc2c2e26f400e373cee2641b45b Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 30 Jan 2025 18:51:00 +0000 Subject: [PATCH 11/15] Handle grouping in Mixer2RMSNormGated Signed-off-by: Tyler Michael Smith --- .../layers/mamba/mamba_mixer2.py | 70 ++++++++++++++----- 1 file changed, 53 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 53adb5623bd80..d01c688e393d2 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -10,6 +10,7 @@ from vllm.attention.backends.xformers import XFormersMetadata from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -33,15 +34,20 @@ @CustomOp.register("mixer2_gated_rms_norm") class Mixer2RMSNormGated(CustomOp): - def __init__(self, hidden_size, eps=1e-6): + def __init__(self, full_hidden_size, full_n_groups, eps=1e-6): super().__init__() - self.hidden_size = hidden_size - self.variance_epsilon = eps - self.weight = nn.Parameter(torch.ones(hidden_size)) self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.full_hidden_size = full_hidden_size + self.group_size = full_hidden_size // full_n_groups + self.per_rank_hidden_size = full_hidden_size // self.tp_size + self.n_groups = full_hidden_size // self.group_size + + self.variance_epsilon = eps + self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size)) set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)}) - assert self.hidden_size % self.tp_size== 0,\ + assert self.full_hidden_size % self.tp_size== 0,\ "Tensor parallel world size must divide hidden size." def forward_native( @@ -49,21 +55,50 @@ def forward_native( x: torch.Tensor, gate: torch.Tensor, ): + # Three tensor-parallel cases: + # 1. n_groups is 1 + # In this case we parallelize along the reduction dim. + # Each rank computes a local sum of squares followed by AllReduce + # 2. tp_size divides n_groups + # Each rank only reduces within its local group(s). + # No collective ops necessary. + # 3. The general case can be pretty complicated so we AllGather + # the input and then redundantly compute the RMSNorm. 
input_dtype = x.dtype x = x * nn.functional.silu(gate.to(torch.float32)) - if self.tp_size > 1: - # Compute local sum and then reduce to obtain global sum - local_sums = x.pow(2).sum(dim=-1, keepdim=True) - global_sums = tensor_model_parallel_all_reduce(local_sums) - # Calculate the variance - count = self.tp_size * x.shape[-1] - variance = (global_sums / count) - + if self.n_groups == 1: + if self.tp_size > 1: + # Compute local sum and then reduce to obtain global sum + local_sums = x.pow(2).sum(dim=-1, keepdim=True) + global_sums = tensor_model_parallel_all_reduce(local_sums) + # Calculate the variance + count = self.tp_size * x.shape[-1] + variance = (global_sums / count) + + else: + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) else: - variance = x.pow(2).mean(-1, keepdim=True) + #redundant_tp: bool = self.n_groups % self.tp_size != 0 + redundant_tp: bool = True + if redundant_tp: + # To handle the general case, redundantly apply the variance + x = tensor_model_parallel_all_gather(x, -1) + + *prefix_dims, hidden_dim = x.shape + group_count = hidden_dim // self.group_size + x_grouped = x.view(*prefix_dims, group_count, self.group_size) + variance = x_grouped.pow(2).mean(-1, keepdim=True) + x_grouped = x_grouped * torch.rsqrt(variance + + self.variance_epsilon) + x = x_grouped.view(*prefix_dims, hidden_dim) + + if redundant_tp: + start = self.per_rank_hidden_size * self.tp_rank + end = start + self.per_rank_hidden_size + x = x[..., start:end] - x = x * torch.rsqrt(variance + self.variance_epsilon) return self.weight * x.to(input_dtype) def forward_cuda( @@ -72,7 +107,7 @@ def forward_cuda( gate: torch.Tensor, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - if self.tp_size > 1: + if self.tp_size > 1 or self.n_groups != 1: return self.forward_native(x, gate) from vllm import _custom_ops as ops @@ -324,7 +359,8 @@ def __init__(self, input_is_parallel=True, quant_config=quant_config) - self.norm = Mixer2RMSNormGated(intermediate_size // self.tp_size, + self.norm = Mixer2RMSNormGated(intermediate_size, + n_groups, eps=rms_norm_eps) def forward_native(self, hidden_states: torch.Tensor, From 0d4bb0f9fc7c16a70ecb1644c3eae1c0208b1eae Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 30 Jan 2025 18:51:39 +0000 Subject: [PATCH 12/15] debug cruft Signed-off-by: Tyler Michael Smith --- vllm/model_executor/layers/mamba/mamba_mixer2.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index d01c688e393d2..055818e748993 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -80,8 +80,7 @@ def forward_native( variance = x.pow(2).mean(-1, keepdim=True) x = x * torch.rsqrt(variance + self.variance_epsilon) else: - #redundant_tp: bool = self.n_groups % self.tp_size != 0 - redundant_tp: bool = True + redundant_tp: bool = self.n_groups % self.tp_size != 0 if redundant_tp: # To handle the general case, redundantly apply the variance x = tensor_model_parallel_all_gather(x, -1) From 74f6088ce576e74a627aa865ac95f37dbfa1161e Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 30 Jan 2025 18:55:36 +0000 Subject: [PATCH 13/15] Remove codestral integration test Signed-off-by: Tyler Michael Smith --- tests/models/decoder_only/language/test_mamba.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git 
a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py index 1c21557c90251..1d32ca30f29bc 100644 --- a/tests/models/decoder_only/language/test_mamba.py +++ b/tests/models/decoder_only/language/test_mamba.py @@ -14,7 +14,10 @@ MODELS = [ "state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev", - "mistralai/Mamba-Codestral-7B-v0.1", + # TODO: Compare to a Mamba2 model. The HF transformers implementation of + # Mamba2 is buggy for Codestral as it doesn't handle n_groups. + # See https://github.com/huggingface/transformers/pull/35943 + # "mistralai/Mamba-Codestral-7B-v0.1", ] From bcc93e6938c66671a58ddf64f88178e6851c48bb Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 7 Feb 2025 00:14:28 +0000 Subject: [PATCH 14/15] spdx Signed-off-by: Tyler Michael Smith --- vllm/model_executor/models/mamba2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 545d151fe05c0..4e2a2c0d7327f 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 """PyTorch MAMBA2 model.""" from typing import Iterable, List, Optional, Set, Tuple From 2e5f7be4928297bf36e0d2ca77980fb80c070e05 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 7 Feb 2025 16:32:16 +0000 Subject: [PATCH 15/15] Fix invalid memory accesses in chunk_scan_fwd See https://github.com/fabianlim/vllm/commit/5d240a005dc0fc63924cd8a28e31311da7d6b9db Signed-off-by: Tyler Michael Smith --- vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 722fbd714ca8f..7ef5111227eb4 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -293,7 +293,8 @@ def _chunk_scan_fwd_kernel( dA_cs_m_boundary = tl.load( dA_cumsum_ptr + (pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize, - mask=(pid_m * BLOCK_SIZE_M + c_off - 1) > -1, + mask=(((pid_m * BLOCK_SIZE_M + c_off - 1) > -1) + and ((pid_m * BLOCK_SIZE_M + c_off) < chunk_size)), other=0.0).to(tl.float32) if HAS_SEQ_IDX: @@ -463,7 +464,10 @@ def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int): p += (s % chunk_size > 0) # get the dimensions - _s, _e = s // chunk_size + p, e // chunk_size + p + 1 + # - the + 1 for _e is to shift the boundary by one chunk + # - this shifting is not needed if chunk_size divides e + _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size + > 0) # adjust inidces and offsets chunk_indices[_s:_e] -= p
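
Note on the final hunk above: the boundary arithmetic being fixed is easier to see in isolation. The sketch below is illustrative only, not the kernel helper itself; it assumes, based on the visible diff, that s and e are the start and end token positions of one sequence in the packed batch, p counts the extra logical chunks introduced by earlier sequences that started mid-chunk, and chunk_size is the SSD chunk length.

def chunk_boundary(s: int, e: int, p: int, chunk_size: int):
    # A sequence that starts mid-chunk occupies one extra logical chunk.
    p += (s % chunk_size > 0)
    lower = s // chunk_size + p
    # Old upper bound: always shift the boundary by one extra chunk.
    upper_old = e // chunk_size + p + 1
    # Fixed upper bound: shift only when the sequence ends mid-chunk,
    # i.e. when chunk_size does not divide e (per the comment added
    # in the hunk).
    upper_new = e // chunk_size + p + (e % chunk_size > 0)
    return lower, upper_old, upper_new

# chunk_size=8, sequence spans tokens [0, 16): exactly two chunks, so no
# shift is needed; the old bound reaches one chunk further than the fix.
print(chunk_boundary(0, 16, 0, 8))  # (0, 3, 2)
# chunk_size=8, sequence spans tokens [0, 12): ends mid-chunk, so the old
# and fixed bounds agree.
print(chunk_boundary(0, 12, 0, 8))  # (0, 2, 2)

Together with the extra "< chunk_size" term added to the mask in _chunk_scan_fwd_kernel, this narrowing of the chunk_indices[_s:_e] range appears to be what the commit message refers to as the invalid memory accesses; that reading is inferred from the visible diff rather than stated in the patch itself.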