[Model] Ultravox Model: Support v0.5 Release #12912

Open. Wants to merge 3 commits into base: main.

Changes from 1 commit
39 changes: 28 additions & 11 deletions vllm/model_executor/models/ultravox.py
@@ -18,7 +18,7 @@
 from vllm import envs
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
+from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
@@ -252,33 +252,50 @@
         return audio_embeds
 
 
+class FlippedSiluAndMul(SiluAndMul):
+    """Ultravox is trained with SwiGLU with flipped halves."""
+
+    def forward(self, x: torch.Tensor):
+        a, b = x.chunk(2, dim=-1)
+        flipped = torch.cat((b, a), dim=-1)
+        return super().forward(flipped)
+
+
 class UltravoxProjector(nn.Module):
 
     def __init__(self, config: UltravoxConfig):
         super().__init__()
         self.hidden_dim = config.hidden_size
         self._pad_and_stack = StackAudioFrames(config.stack_factor)
-        dim = config.audio_config.hidden_size * config.stack_factor
-        self.ln_pre = RMSNorm(dim)
-        self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
-        dim = self.hidden_dim
+        dim_in = config.audio_config.hidden_size * config.stack_factor
+        self.ln_pre = RMSNorm(dim_in)
+        self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False)
+        dim_mid = self.hidden_dim
 
         if config.projector_act == "swiglu":
-            self.act = MulAndSilu()
-            dim = dim // 2
+            self.act = FlippedSiluAndMul()
+            dim_mid = dim_mid // 2
         else:
             self.act = get_act_fn(config.projector_act)
 
-        self.linear_2 = nn.Linear(dim,
-                                  config.text_config.hidden_size,
-                                  bias=False)
-        self.ln_post = RMSNorm(config.text_config.hidden_size)
+        dim_out = config.text_config.hidden_size
+        self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
+
+        # Ultravox v0.4.1 and below uses layer_norm after the second linear layer,
+        # while v0.5.0 and above uses layer_norm after the first linear layer.
+        if config.projector_ln_mid:
+            self.ln_mid: nn.Module = RMSNorm(dim_mid)
+            self.ln_post = nn.Identity()
+        else:
+            self.ln_mid = nn.Identity()
+            self.ln_post = RMSNorm(dim_out)
 
     def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
         audio_features = self._pad_and_stack(audio_features)
         audio_features = self.ln_pre(audio_features)
         hidden_states = self.linear_1(audio_features)
         hidden_states = self.act(hidden_states)
+        hidden_states = self.ln_mid(hidden_states)
         hidden_states = self.linear_2(hidden_states)
         hidden_states = self.ln_post(hidden_states)
         return hidden_states

GitHub Actions / pre-commit annotation at vllm/model_executor/models/ultravox.py:284:81: Ruff (E501) Line too long (82 > 80).
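For readers unfamiliar with the flipped activation in this diff: FlippedSiluAndMul swaps the two halves of the projected tensor before applying vLLM's SiluAndMul, so the gate becomes SiLU(second half) * first half rather than SiLU(first half) * second half. A minimal standalone sketch of that equivalence in plain PyTorch (illustrative only; the function names below are not part of the diff):

import torch
import torch.nn.functional as F


def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Same math as vLLM's SiluAndMul: SiLU(first half) * second half.
    a, b = x.chunk(2, dim=-1)
    return F.silu(a) * b


def flipped_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Ultravox ordering from the diff: swap the halves first,
    # which yields SiLU(second half) * first half.
    a, b = x.chunk(2, dim=-1)
    return silu_and_mul(torch.cat((b, a), dim=-1))


x = torch.randn(4, 8)
a, b = x.chunk(2, dim=-1)
assert torch.allclose(flipped_silu_and_mul(x), F.silu(b) * a)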
10 changes: 10 additions & 0 deletions vllm/transformers_utils/configs/ultravox.py
@@ -37,6 +37,12 @@ class UltravoxConfig(transformers.PretrainedConfig):
             The LoRA configuration for finetuning the text model.
         audio_model_lora_config (`LoraConfigSimplified`, *optional*):
             The LoRA configuration for finetuning the audio model.
+        projector_ln_mid (`bool`, *optional*, defaults to `False`):
+            Whether to apply layer normalization at the middle of the
+            projector or at the end. Versions v0.4.1 and below
+            use `False`, but v0.5 and above use `True`.
+        audio_latency_block_size (`int`, *optional*, defaults to `None`):
+            The latency block size for simulating audio streaming.
     """
 
     model_type = "ultravox"
@@ -56,6 +62,8 @@ def __init__(
         projector_act: str = "swiglu",
         text_model_lora_config: Optional[Dict[str, Any]] = None,
         audio_model_lora_config: Optional[Dict[str, Any]] = None,
+        projector_ln_mid: bool = False,
+        audio_latency_block_size: Optional[int] = None,
         **kwargs,
     ):
         self.ignore_index = ignore_index
@@ -68,6 +76,8 @@ def __init__(
         self.stack_factor = stack_factor
         self.norm_init = norm_init
         self.projector_act = projector_act
+        self.projector_ln_mid = projector_ln_mid
+        self.audio_latency_block_size = audio_latency_block_size
 
         if text_model_id is not None:
             # Avoid circular import
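As a usage sketch of the two new config fields (illustrative only; this snippet is not part of the PR, and it assumes UltravoxConfig can be constructed with its default sub-configs):

from vllm.transformers_utils.configs.ultravox import UltravoxConfig

# v0.4.1-style checkpoint: layer norm stays after the second linear layer.
cfg_v04 = UltravoxConfig(projector_act="swiglu")
assert cfg_v04.projector_ln_mid is False

# v0.5-style checkpoint: layer norm moves to after the first linear layer.
cfg_v05 = UltravoxConfig(projector_act="swiglu", projector_ln_mid=True)
assert cfg_v05.projector_ln_mid is True
assert cfg_v05.audio_latency_block_size is None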