From 35c3e173ea3af51d2e7fcc3f2794d41f238d050f Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Fri, 7 Feb 2025 09:27:58 -0800 Subject: [PATCH 1/5] update ultravox model to support v0.5 release Signed-off-by: Farzad Abdolhosseini --- vllm/model_executor/models/ultravox.py | 39 +++++++++++++++------ vllm/transformers_utils/configs/ultravox.py | 10 ++++++ 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 9da0682cfa866..74f69a01e8922 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -18,7 +18,7 @@ from vllm import envs from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn +from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.loader import DefaultModelLoader @@ -252,33 +252,50 @@ def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor: return audio_embeds +class FlippedSiluAndMul(SiluAndMul): + """Ultravox is trained with SwiGLU with flipped halves.""" + + def forward(self, x: torch.Tensor): + a, b = x.chunk(2, dim=-1) + flipped = torch.cat((b, a), dim=-1) + return super().forward(flipped) + + class UltravoxProjector(nn.Module): def __init__(self, config: UltravoxConfig): super().__init__() self.hidden_dim = config.hidden_size self._pad_and_stack = StackAudioFrames(config.stack_factor) - dim = config.audio_config.hidden_size * config.stack_factor - self.ln_pre = RMSNorm(dim) - self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False) - dim = self.hidden_dim + dim_in = config.audio_config.hidden_size * config.stack_factor + self.ln_pre = RMSNorm(dim_in) + self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False) + dim_mid = self.hidden_dim if config.projector_act == "swiglu": - self.act = MulAndSilu() - dim = dim // 2 + self.act = FlippedSiluAndMul() + dim_mid = dim_mid // 2 else: self.act = get_act_fn(config.projector_act) - self.linear_2 = nn.Linear(dim, - config.text_config.hidden_size, - bias=False) - self.ln_post = RMSNorm(config.text_config.hidden_size) + dim_out = config.text_config.hidden_size + self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False) + + # Ultravox v0.4.1 and below uses layer_norm after the second linear layer, + # while v0.5.0 and above uses layer_norm after the first linear layer. + if config.projector_ln_mid: + self.ln_mid: nn.Module = RMSNorm(dim_mid) + self.ln_post = nn.Identity() + else: + self.ln_mid = nn.Identity() + self.ln_post = RMSNorm(dim_out) def forward(self, audio_features: torch.Tensor) -> torch.Tensor: audio_features = self._pad_and_stack(audio_features) audio_features = self.ln_pre(audio_features) hidden_states = self.linear_1(audio_features) hidden_states = self.act(hidden_states) + hidden_states = self.ln_mid(hidden_states) hidden_states = self.linear_2(hidden_states) hidden_states = self.ln_post(hidden_states) return hidden_states diff --git a/vllm/transformers_utils/configs/ultravox.py b/vllm/transformers_utils/configs/ultravox.py index 99715ba6d0b09..a5643f54ebcef 100644 --- a/vllm/transformers_utils/configs/ultravox.py +++ b/vllm/transformers_utils/configs/ultravox.py @@ -37,6 +37,12 @@ class UltravoxConfig(transformers.PretrainedConfig): The LoRA configuration for finetuning the text model. 
         audio_model_lora_config (`LoraConfigSimplified`, *optional*):
             The LoRA configuration for finetuning the audio model.
+        projector_ln_mid (`bool`, *optional*, defaults to `False`):
+            Whether to apply layer normalization at the middle of the
+            projector or at the end. Versions v0.4.1 and below
+            use `False`, but v0.5 and above use `True`.
+        audio_latency_block_size (`int`, *optional*, defaults to `None`):
+            The latency block size for simulating audio streaming.
     """

     model_type = "ultravox"
@@ -56,6 +62,8 @@ def __init__(
         projector_act: str = "swiglu",
         text_model_lora_config: Optional[Dict[str, Any]] = None,
         audio_model_lora_config: Optional[Dict[str, Any]] = None,
+        projector_ln_mid: bool = False,
+        audio_latency_block_size: Optional[int] = None,
         **kwargs,
     ):
         self.ignore_index = ignore_index
@@ -68,6 +76,8 @@ def __init__(
         self.stack_factor = stack_factor
         self.norm_init = norm_init
         self.projector_act = projector_act
+        self.projector_ln_mid = projector_ln_mid
+        self.audio_latency_block_size = audio_latency_block_size

         if text_model_id is not None:
             # Avoid circular import

From 323abb00c29e2f20cc17e03c1b51c89ae2524c9e Mon Sep 17 00:00:00 2001
From: Farzad Abdolhosseini
Date: Fri, 7 Feb 2025 11:09:07 -0800
Subject: [PATCH 2/5] revert to using MulAndSilu instead of FlippedSiluAndMul

Signed-off-by: Farzad Abdolhosseini
---
 vllm/model_executor/models/ultravox.py | 15 +++------------
 1 file changed, 3 insertions(+), 12 deletions(-)

diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index 74f69a01e8922..063997a14a66f 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -18,7 +18,7 @@
 from vllm import envs
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
+from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
@@ -252,15 +252,6 @@ def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
         return audio_embeds


-class FlippedSiluAndMul(SiluAndMul):
-    """Ultravox is trained with SwiGLU with flipped halves."""
-
-    def forward(self, x: torch.Tensor):
-        a, b = x.chunk(2, dim=-1)
-        flipped = torch.cat((b, a), dim=-1)
-        return super().forward(flipped)
-
-
 class UltravoxProjector(nn.Module):

     def __init__(self, config: UltravoxConfig):
@@ -273,7 +264,7 @@ def __init__(self, config: UltravoxConfig):
         dim_mid = self.hidden_dim

         if config.projector_act == "swiglu":
-            self.act = FlippedSiluAndMul()
+            self.act = MulAndSilu()
             dim_mid = dim_mid // 2
         else:
             self.act = get_act_fn(config.projector_act)
@@ -281,7 +272,7 @@ def __init__(self, config: UltravoxConfig):
         dim_out = config.text_config.hidden_size
         self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)

-        # Ultravox v0.4.1 and below uses layer_norm after the second linear layer,
+        # Ultravox v0.4.1 and below use layer_norm after the second linear layer
         # while v0.5.0 and above uses layer_norm after the first linear layer.
if config.projector_ln_mid: self.ln_mid: nn.Module = RMSNorm(dim_mid) From a3ea1e9dc77b8a7bfac45c602e50e263c1ded363 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Fri, 7 Feb 2025 13:58:57 -0800 Subject: [PATCH 3/5] update tests to use ultravox v0.5 Signed-off-by: Farzad Abdolhosseini --- docs/source/serving/multimodal_inputs.md | 4 ++-- examples/offline_inference/audio_language.py | 4 ++-- .../openai_chat_completion_client_for_multimodal.py | 2 +- tests/distributed/test_pipeline_parallel.py | 4 ++-- tests/entrypoints/openai/test_audio.py | 2 +- tests/entrypoints/test_chat_utils.py | 2 +- tests/models/decoder_only/audio_language/test_ultravox.py | 2 +- tests/models/multimodal/processing/test_common.py | 2 +- tests/models/registry.py | 2 +- 9 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/source/serving/multimodal_inputs.md b/docs/source/serving/multimodal_inputs.md index 217b531e83788..ade59e3773839 100644 --- a/docs/source/serving/multimodal_inputs.md +++ b/docs/source/serving/multimodal_inputs.md @@ -359,12 +359,12 @@ export VLLM_VIDEO_FETCH_TIMEOUT= ### Audio Audio input is supported according to [OpenAI Audio API](https://platform.openai.com/docs/guides/audio?audio-generation-quickstart-example=audio-in). -Here is a simple example using Ultravox-v0.3. +Here is a simple example using Ultravox-v0.5-1B. First, launch the OpenAI-compatible server: ```bash -vllm serve fixie-ai/ultravox-v0_3 +vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b ``` Then, you can use the OpenAI client as follows: diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 707ca9f878961..3e3034a02f0f1 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -24,9 +24,9 @@ # Unless specified, these settings have been tested to work on a single L4. 
-# Ultravox 0.3 +# Ultravox 0.5-1B def run_ultravox(question: str, audio_count: int): - model_name = "fixie-ai/ultravox-v0_3" + model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b" tokenizer = AutoTokenizer.from_pretrained(model_name) messages = [{ diff --git a/examples/online_serving/openai_chat_completion_client_for_multimodal.py b/examples/online_serving/openai_chat_completion_client_for_multimodal.py index d5f798a8dae62..ecfcf05a90d16 100644 --- a/examples/online_serving/openai_chat_completion_client_for_multimodal.py +++ b/examples/online_serving/openai_chat_completion_client_for_multimodal.py @@ -12,7 +12,7 @@ --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2 (audio inference with Ultravox) -vllm serve fixie-ai/ultravox-v0_3 --max-model-len 4096 +vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b --max-model-len 4096 """ import base64 diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 5b6741d74efc0..5d7cb9e408909 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -215,7 +215,7 @@ def iter_params(self, model_name: str): "Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True), "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(), "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(), - "fixie-ai/ultravox-v0_3": PPTestSettings.fast(trust_remote_code=True), + "fixie-ai/ultravox-v0_5-llama-3_2-1b": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 # [Encoder-decoder] # TODO: Implement PP # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(), @@ -234,7 +234,7 @@ def iter_params(self, model_name: str): # [MULTIMODAL GENERATION] "OpenGVLab/InternVL2-1B", "microsoft/Phi-3-vision-128k-instruct", - "fixie-ai/ultravox-v0_3", + "fixie-ai/ultravox-v0_5-llama-3_2-1b", # [LANGUAGE GENERATION - HYBRID ARCH] "ai21labs/Jamba-tiny-dev", ] diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index 3459f24834dbc..fe7299a48e6f6 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -11,7 +11,7 @@ from ...utils import RemoteOpenAIServer -MODEL_NAME = "fixie-ai/ultravox-v0_3" +MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b" TEST_AUDIO_URLS = [ AudioAsset("winning_call").url, ] diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 5c469007af23e..c52fa905c80b3 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -21,7 +21,7 @@ EXAMPLES_DIR = VLLM_PATH / "examples" PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" -ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_3" +ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b" QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct" MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct" LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B" diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index fe9361d126120..d1f643a8fdb73 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -15,7 +15,7 @@ from ....utils import RemoteOpenAIServer from ...utils import check_logprobs_close -MODEL_NAME = "fixie-ai/ultravox-v0_3" +MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b" AudioTuple = Tuple[np.ndarray, int] diff --git a/tests/models/multimodal/processing/test_common.py 
b/tests/models/multimodal/processing/test_common.py index 77cf3442df905..e555e1d6d9872 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -163,7 +163,7 @@ def _test_processing_correctness( "Qwen/Qwen2-VL-2B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct", - "fixie-ai/ultravox-v0_3", + "fixie-ai/ultravox-v0_5-llama-3_2-1b", ]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("num_batches", [32]) diff --git a/tests/models/registry.py b/tests/models/registry.py index 3fd94b89c8a60..66b7d3c2e77b5 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -267,7 +267,7 @@ def check_available_online( "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501 "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501 min_transformers_version="4.49"), # noqa: E501 - "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3", + "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", trust_remote_code=True), # [Encoder-decoder] "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 From da6cabd3c3654a2f75492ffcbaf78a7bf8c32363 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Mon, 10 Feb 2025 09:36:34 -0800 Subject: [PATCH 4/5] remove audio_latency_block_size for now Signed-off-by: Farzad Abdolhosseini --- vllm/transformers_utils/configs/ultravox.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/transformers_utils/configs/ultravox.py b/vllm/transformers_utils/configs/ultravox.py index a5643f54ebcef..6b2765db94e78 100644 --- a/vllm/transformers_utils/configs/ultravox.py +++ b/vllm/transformers_utils/configs/ultravox.py @@ -41,8 +41,6 @@ class UltravoxConfig(transformers.PretrainedConfig): Whether to apply layer normalization at the middle of the projector or at the end. Versions v0.4.1 and below use `False`, but v0.5 and above use `True`. - audio_latency_block_size (`int`, *optional*, defaults to `None`): - The latency block size for simulating audio streaming. """ model_type = "ultravox" @@ -63,7 +61,6 @@ def __init__( text_model_lora_config: Optional[Dict[str, Any]] = None, audio_model_lora_config: Optional[Dict[str, Any]] = None, projector_ln_mid: bool = False, - audio_latency_block_size: Optional[int] = None, **kwargs, ): self.ignore_index = ignore_index @@ -77,7 +74,6 @@ def __init__( self.norm_init = norm_init self.projector_act = projector_act self.projector_ln_mid = projector_ln_mid - self.audio_latency_block_size = audio_latency_block_size if text_model_id is not None: # Avoid circular import From 607b70ffb411f6c4511425848524d5a15855375e Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Mon, 10 Feb 2025 10:25:09 -0800 Subject: [PATCH 5/5] update ultravox version in supported_models Signed-off-by: Farzad Abdolhosseini --- docs/source/models/supported_models.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 32f3e9deff671..45e5f706eb9dc 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -856,7 +856,7 @@ See [this page](#generative-models) for more information on how to use generativ - * `UltravoxModel` * Ultravox * T + AE+ - * `fixie-ai/ultravox-v0_3` + * `fixie-ai/ultravox-v0_5-llama-3_2-1b` * ✅︎ * ✅︎ * ✅︎
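
Note on PATCH 1/5 and PATCH 2/5: the `FlippedSiluAndMul` introduced in the
first patch (split the input in half, swap the halves, then apply
`SiluAndMul`) computes the same function as vLLM's existing `MulAndSilu`,
which is why the second patch can drop the subclass without changing
behavior. A minimal sketch checking that equivalence, using plain-PyTorch
stand-ins for the vLLM activation classes (the function bodies below are
re-implementations for illustration, not the vLLM code):

```python
import torch
import torch.nn.functional as F

def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for vLLM's SiluAndMul: silu(first half) * second half.
    a, b = x.chunk(2, dim=-1)
    return F.silu(a) * b

def flipped_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # PATCH 1/5's FlippedSiluAndMul: swap the halves, then SiluAndMul,
    # i.e. silu(second half) * first half.
    a, b = x.chunk(2, dim=-1)
    return silu_and_mul(torch.cat((b, a), dim=-1))

def mul_and_silu(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for vLLM's MulAndSilu: first half * silu(second half).
    a, b = x.chunk(2, dim=-1)
    return a * F.silu(b)

x = torch.randn(4, 32)
assert torch.allclose(flipped_silu_and_mul(x), mul_and_silu(x))
```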
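Note on `projector_ln_mid` (PATCH 1/5): the flag only moves the second
RMSNorm. For v0.4.1-style checkpoints (`projector_ln_mid=False`) the
normalization runs after `linear_2`; for v0.5-style checkpoints
(`projector_ln_mid=True`) it runs between the activation and `linear_2`.
A self-contained sketch of that data flow, using `torch.nn.RMSNorm`
(PyTorch 2.4+) in place of vLLM's `RMSNorm` and omitting the audio-frame
stacking step:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ProjectorSketch(nn.Module):
    def __init__(self, dim_in: int, hidden: int, dim_out: int,
                 projector_ln_mid: bool):
        super().__init__()
        self.ln_pre = nn.RMSNorm(dim_in)
        self.linear_1 = nn.Linear(dim_in, hidden, bias=False)
        dim_mid = hidden // 2  # SwiGLU halves the width
        self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
        if projector_ln_mid:   # v0.5 and above
            self.ln_mid: nn.Module = nn.RMSNorm(dim_mid)
            self.ln_post: nn.Module = nn.Identity()
        else:                  # v0.4.1 and below
            self.ln_mid = nn.Identity()
            self.ln_post = nn.RMSNorm(dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear_1(self.ln_pre(x))
        a, b = x.chunk(2, dim=-1)
        x = a * F.silu(b)       # MulAndSilu stand-in
        x = self.ln_mid(x)      # no-op for v0.4.1-style checkpoints
        x = self.linear_2(x)
        return self.ln_post(x)  # no-op for v0.5-style checkpoints
```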
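For completeness, the serving flow that the doc updates in PATCH 3/5 point
at boils down to the client call below. The URL and audio file are
hypothetical placeholders; the `audio_url` content type follows the pattern
shown in docs/source/serving/multimodal_inputs.md:

```python
from openai import OpenAI

# Assumes `vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b` is running locally.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat = client.chat.completions.create(
    model="fixie-ai/ultravox-v0_5-llama-3_2-1b",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is happening in this recording?"},
            # Hypothetical audio URL; any fetchable WAV/MP3 works.
            {"type": "audio_url",
             "audio_url": {"url": "https://example.com/sample.wav"}},
        ],
    }],
)
print(chat.choices[0].message.content)
```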