From 35c3e173ea3af51d2e7fcc3f2794d41f238d050f Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Fri, 7 Feb 2025 09:27:58 -0800 Subject: [PATCH] update ultravox model to support v0.5 release Signed-off-by: Farzad Abdolhosseini --- vllm/model_executor/models/ultravox.py | 39 +++++++++++++++------ vllm/transformers_utils/configs/ultravox.py | 10 ++++++ 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 9da0682cfa866..74f69a01e8922 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -18,7 +18,7 @@ from vllm import envs from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn +from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.loader import DefaultModelLoader @@ -252,33 +252,50 @@ def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor: return audio_embeds +class FlippedSiluAndMul(SiluAndMul): + """Ultravox is trained with SwiGLU with flipped halves.""" + + def forward(self, x: torch.Tensor): + a, b = x.chunk(2, dim=-1) + flipped = torch.cat((b, a), dim=-1) + return super().forward(flipped) + + class UltravoxProjector(nn.Module): def __init__(self, config: UltravoxConfig): super().__init__() self.hidden_dim = config.hidden_size self._pad_and_stack = StackAudioFrames(config.stack_factor) - dim = config.audio_config.hidden_size * config.stack_factor - self.ln_pre = RMSNorm(dim) - self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False) - dim = self.hidden_dim + dim_in = config.audio_config.hidden_size * config.stack_factor + self.ln_pre = RMSNorm(dim_in) + self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False) + dim_mid = self.hidden_dim if config.projector_act == "swiglu": - self.act = MulAndSilu() - dim = dim // 2 + self.act = FlippedSiluAndMul() + dim_mid = dim_mid // 2 else: self.act = get_act_fn(config.projector_act) - self.linear_2 = nn.Linear(dim, - config.text_config.hidden_size, - bias=False) - self.ln_post = RMSNorm(config.text_config.hidden_size) + dim_out = config.text_config.hidden_size + self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False) + + # Ultravox v0.4.1 and below uses layer_norm after the second linear layer, + # while v0.5.0 and above uses layer_norm after the first linear layer. + if config.projector_ln_mid: + self.ln_mid: nn.Module = RMSNorm(dim_mid) + self.ln_post = nn.Identity() + else: + self.ln_mid = nn.Identity() + self.ln_post = RMSNorm(dim_out) def forward(self, audio_features: torch.Tensor) -> torch.Tensor: audio_features = self._pad_and_stack(audio_features) audio_features = self.ln_pre(audio_features) hidden_states = self.linear_1(audio_features) hidden_states = self.act(hidden_states) + hidden_states = self.ln_mid(hidden_states) hidden_states = self.linear_2(hidden_states) hidden_states = self.ln_post(hidden_states) return hidden_states diff --git a/vllm/transformers_utils/configs/ultravox.py b/vllm/transformers_utils/configs/ultravox.py index 99715ba6d0b09..a5643f54ebcef 100644 --- a/vllm/transformers_utils/configs/ultravox.py +++ b/vllm/transformers_utils/configs/ultravox.py @@ -37,6 +37,12 @@ class UltravoxConfig(transformers.PretrainedConfig): The LoRA configuration for finetuning the text model. audio_model_lora_config (`LoraConfigSimplified`, *optional*): The LoRA configuration for finetuning the audio model. + projector_ln_mid (`bool`, *optional*, defaults to `False`): + Whether to apply layer normalization at the middle of the + projector or at the end. Versions v0.4.1 and below + use `False`, but v0.5 and above use `True`. + audio_latency_block_size (`int`, *optional*, defaults to `None`): + The latency block size for simulating audio streaming. """ model_type = "ultravox" @@ -56,6 +62,8 @@ def __init__( projector_act: str = "swiglu", text_model_lora_config: Optional[Dict[str, Any]] = None, audio_model_lora_config: Optional[Dict[str, Any]] = None, + projector_ln_mid: bool = False, + audio_latency_block_size: Optional[int] = None, **kwargs, ): self.ignore_index = ignore_index @@ -68,6 +76,8 @@ def __init__( self.stack_factor = stack_factor self.norm_init = norm_init self.projector_act = projector_act + self.projector_ln_mid = projector_ln_mid + self.audio_latency_block_size = audio_latency_block_size if text_model_id is not None: # Avoid circular import