update ultravox model to support v0.5 release
Signed-off-by: Farzad Abdolhosseini <[email protected]>
farzadab committed Feb 7, 2025
1 parent eaa92d4 commit 35c3e17
Showing 2 changed files with 38 additions and 11 deletions.
39 changes: 28 additions & 11 deletions vllm/model_executor/models/ultravox.py
@@ -18,7 +18,7 @@
 from vllm import envs
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
+from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
@@ -252,33 +252,50 @@ def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
         return audio_embeds
 
 
+class FlippedSiluAndMul(SiluAndMul):
+    """Ultravox is trained with SwiGLU with flipped halves."""
+
+    def forward(self, x: torch.Tensor):
+        a, b = x.chunk(2, dim=-1)
+        flipped = torch.cat((b, a), dim=-1)
+        return super().forward(flipped)
+
+
 class UltravoxProjector(nn.Module):
 
     def __init__(self, config: UltravoxConfig):
         super().__init__()
         self.hidden_dim = config.hidden_size
         self._pad_and_stack = StackAudioFrames(config.stack_factor)
-        dim = config.audio_config.hidden_size * config.stack_factor
-        self.ln_pre = RMSNorm(dim)
-        self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
-        dim = self.hidden_dim
+        dim_in = config.audio_config.hidden_size * config.stack_factor
+        self.ln_pre = RMSNorm(dim_in)
+        self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False)
+        dim_mid = self.hidden_dim
 
         if config.projector_act == "swiglu":
-            self.act = MulAndSilu()
-            dim = dim // 2
+            self.act = FlippedSiluAndMul()
+            dim_mid = dim_mid // 2
         else:
             self.act = get_act_fn(config.projector_act)
 
-        self.linear_2 = nn.Linear(dim,
-                                  config.text_config.hidden_size,
-                                  bias=False)
-        self.ln_post = RMSNorm(config.text_config.hidden_size)
+        dim_out = config.text_config.hidden_size
+        self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
+
+        # Ultravox v0.4.1 and below uses layer_norm after the second linear layer,
+        # while v0.5.0 and above uses layer_norm after the first linear layer.
+        if config.projector_ln_mid:
+            self.ln_mid: nn.Module = RMSNorm(dim_mid)
+            self.ln_post = nn.Identity()
+        else:
+            self.ln_mid = nn.Identity()
+            self.ln_post = RMSNorm(dim_out)
 
     def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
         audio_features = self._pad_and_stack(audio_features)
         audio_features = self.ln_pre(audio_features)
         hidden_states = self.linear_1(audio_features)
         hidden_states = self.act(hidden_states)
+        hidden_states = self.ln_mid(hidden_states)
         hidden_states = self.linear_2(hidden_states)
         hidden_states = self.ln_post(hidden_states)
         return hidden_states
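
To make the change easier to verify outside vLLM, here is a minimal self-contained sketch in plain PyTorch (SiluAndMul is inlined rather than imported from vLLM, and nn.RMSNorm assumes PyTorch >= 2.4). It checks two things: that the flipped SwiGLU matches the a * silu(b) orientation Ultravox was trained with, and that the projector_ln_mid switch only moves the RMSNorm without changing output shapes. Sizes and names below are illustrative, not part of the commit.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SiluAndMul(nn.Module):
    # Inline stand-in for vLLM's SiluAndMul: silu(x[..., :d]) * x[..., d:].
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

class FlippedSiluAndMul(SiluAndMul):
    # As in the diff: swap the two halves, then apply SiluAndMul.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a, b = x.chunk(2, dim=-1)
        flipped = torch.cat((b, a), dim=-1)
        return super().forward(flipped)

x = torch.randn(4, 16)
a, b = x.chunk(2, dim=-1)
# Flipping turns silu(first_half) * second_half into first_half * silu(second_half).
assert torch.allclose(FlippedSiluAndMul()(x), a * F.silu(b))

# RMSNorm placement: ln_post for v0.4.1-style configs, ln_mid for v0.5-style.
dim_mid, dim_out = 8, 6  # illustrative sizes only
linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
for projector_ln_mid in (False, True):
    ln_mid = nn.RMSNorm(dim_mid) if projector_ln_mid else nn.Identity()
    ln_post = nn.Identity() if projector_ln_mid else nn.RMSNorm(dim_out)
    out = ln_post(linear_2(ln_mid(FlippedSiluAndMul()(x))))
    assert out.shape == (4, dim_out)  # moving the norm does not change shapes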
10 changes: 10 additions & 0 deletions vllm/transformers_utils/configs/ultravox.py
@@ -37,6 +37,12 @@ class UltravoxConfig(transformers.PretrainedConfig):
             The LoRA configuration for finetuning the text model.
         audio_model_lora_config (`LoraConfigSimplified`, *optional*):
             The LoRA configuration for finetuning the audio model.
+        projector_ln_mid (`bool`, *optional*, defaults to `False`):
+            Whether to apply layer normalization at the middle of the
+            projector or at the end. Versions v0.4.1 and below
+            use `False`, but v0.5 and above use `True`.
+        audio_latency_block_size (`int`, *optional*, defaults to `None`):
+            The latency block size for simulating audio streaming.
     """
 
     model_type = "ultravox"
@@ -56,6 +62,8 @@ def __init__(
         projector_act: str = "swiglu",
         text_model_lora_config: Optional[Dict[str, Any]] = None,
         audio_model_lora_config: Optional[Dict[str, Any]] = None,
+        projector_ln_mid: bool = False,
+        audio_latency_block_size: Optional[int] = None,
         **kwargs,
     ):
         self.ignore_index = ignore_index
@@ -68,6 +76,8 @@
         self.stack_factor = stack_factor
         self.norm_init = norm_init
         self.projector_act = projector_act
+        self.projector_ln_mid = projector_ln_mid
+        self.audio_latency_block_size = audio_latency_block_size
 
         if text_model_id is not None:
             # Avoid circular import
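
A hedged usage sketch for the two new config fields (assuming the constructor's remaining arguments keep their defaults; as is usual for PretrainedConfig subclasses, these kwargs are expected to round-trip through a checkpoint's config.json):

from vllm.transformers_utils.configs.ultravox import UltravoxConfig

# v0.5-style checkpoints set projector_ln_mid=True; older checkpoints omit
# the field and fall back to the default of False, keeping v0.4.1 behavior.
config = UltravoxConfig(projector_ln_mid=True)
assert config.projector_ln_mid is True
assert config.audio_latency_block_size is None  # streaming simulation off

# Inside UltravoxProjector.__init__, the flag selects the RMSNorm placement:
# ln_mid = RMSNorm(dim_mid) with ln_post = Identity() when True, else the reverse.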
