[Frontend] Add /v1/audio/transcriptions OpenAI API endpoint #12909

Open · wants to merge 10 commits into base: main
16 changes: 16 additions & 0 deletions .buildkite/asr-eval/run-tests.sh
@@ -0,0 +1,16 @@
#!/bin/bash
set -x

# Start server
python3 -m vllm.entrypoints.openai.api_server --model openai/whisper-large-v3 $@ &
server_pid=$!

# Wait for the server to start; time out after 180 seconds
timeout 180 bash -c 'until curl localhost:8000/v1/models; do sleep 4; done' || exit 1

# NOTE: Expected WER measured with the equivalent hf.transformers model on the same dataset.
# Original dataset split is about 23GB in size, hence we use a pre-filtered slice.
python test_transcription_api_correctness.py -m openai/whisper-large-v3 -dr D4nt3/esb-datasets-earnings22-validation-tiny-filtered --expected-wer 12.744980

# Shut down the server
kill $server_pid
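
The `--expected-wer` value above (12.744980) is, per the note in the script, the WER obtained with the equivalent `hf.transformers` model on the same pre-filtered dataset. A minimal sketch of how such a baseline could be reproduced offline (illustrative only, not part of the PR; for a strictly like-for-like figure the texts would also be passed through the Whisper tokenizer's `normalize`, as the evaluation script below does):

```python
# Sketch: reproduce a reference WER with the Hugging Face pipeline
# (assumed workflow behind the --expected-wer baseline; not part of the PR).
from datasets import load_dataset
from evaluate import load
from transformers import pipeline

asr = pipeline("automatic-speech-recognition",
               model="openai/whisper-large-v3")
ds = load_dataset("D4nt3/esb-datasets-earnings22-validation-tiny-filtered",
                  split="validation")
predictions = [
    asr({
        "raw": sample["audio"]["array"],
        "sampling_rate": sample["audio"]["sampling_rate"]
    })["text"] for sample in ds
]
references = [sample["text"] for sample in ds]
print("WER:", 100 * load("wer").compute(references=references,
                                        predictions=predictions))
```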
208 changes: 208 additions & 0 deletions .buildkite/asr-eval/test_transcription_api_correctness.py
@@ -0,0 +1,208 @@
# SPDX-License-Identifier: Apache-2.0
"""
Evaluate Transcription API correctness by computing Word Error Rate (WER)
on a given ASR dataset. When provided, it will also compare the WER against
a baseline.
"""
import asyncio
import io
import time
from argparse import ArgumentParser
from statistics import mean, median
from typing import List, Optional

import librosa
import soundfile
import torch
from datasets import load_dataset
from evaluate import load
from openai import AsyncOpenAI
from transformers import AutoTokenizer

openai_api_base = "http://localhost:8000/v1"
client = AsyncOpenAI(api_key="EMPTY", base_url=openai_api_base)


def to_bytes(y, sr):
    buffer = io.BytesIO()
    soundfile.write(buffer, y, sr, format="WAV")
    buffer.seek(0)
    return buffer


async def transcribe_audio(client, tokenizer, y, sr):
    status = 200
    try:
        # Send the loaded audio directly instead of loading from disk;
        # don't account for that time in the latency measurement.
        with to_bytes(y, sr) as f:
            start_time = time.perf_counter()
            transcription = await client.audio.transcriptions.create(
                file=f,
                model=tokenizer.name_or_path,
                language="en",
                temperature=0.0,
            )
            end_time = time.perf_counter()
            # NOTE: there's no streaming in transcriptions, can't measure TTFT.
    except Exception as e:
        print(f"Error: {e}")
        status = 500
    # Hard check on the server working properly
    assert status == 200
    latency = end_time - start_time
    num_output_tokens = len(
        tokenizer(transcription.text, add_special_tokens=False).input_ids)
    return latency, num_output_tokens, transcription.text


async def bound_transcribe(model_name, sem, client, audio, reference):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Use semaphore to limit concurrent requests.
    async with sem:
        result = await transcribe_audio(client, tokenizer, *audio)
        # Normalize *English* output/reference for evaluation.
        out = tokenizer.normalize(result[2])
        ref = tokenizer.normalize(reference)
        return result[:2] + (out, ref)


async def process_dataset(model, data, concurrent_request):
    sem = asyncio.Semaphore(concurrent_request)
    tasks: List[asyncio.Task] = []
    for sample in data:
        audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
        task = asyncio.create_task(
            bound_transcribe(model, sem, client, (audio, sr), sample["text"]))
        tasks.append(task)
    return await asyncio.gather(*tasks)


def print_performance_metrics(results, total_time):
    latencies = [res[0] for res in results]
    total_tokens = sum([res[1] for res in results])

    total = len(results)
    print(f"Total Requests: {total}")
    print(f"Successful Requests: {len(latencies)}")
    print(f"Average Latency: {mean(latencies):.4f} seconds")
    print(f"Median Latency: {median(latencies):.4f} seconds")
    perc = sorted(latencies)[int(len(latencies) * 0.95) - 1]
    print(f"95th Percentile Latency: {perc:.4f} seconds")
    # Throughput
    req_throughput = len(latencies) / total_time
    print(f"Estimated req_Throughput: {req_throughput:.2f} requests/s")
    throughput = total_tokens / total_time
    print(f"Estimated Throughput: {throughput:.2f} tok/s")


def add_duration(sample):
    y, sr = sample['audio']["array"], sample['audio']["sampling_rate"]
    sample['duration_ms'] = librosa.get_duration(y=y, sr=sr) * 1000
    return sample


def load_hf_dataset(dataset_repo: str,
                    dataset_name: str,
                    split='validation',
                    **hf_kwargs):
    # Load and filter the dataset
    dataset = load_dataset(dataset_repo,
                           dataset_name,
                           split=split,
                           **hf_kwargs)
    if 'duration_ms' not in dataset[0]:
        # compute duration to filter
        dataset = dataset.map(add_duration)

    # Whisper max supported duration
    dataset = dataset.filter(lambda example: example['duration_ms'] < 30000)
    return dataset


def run_evaluation(model: str,
                   dataset,
                   n_examples: int = -1,
                   max_concurrent_reqs: Optional[int] = None,
                   print_metrics: bool = True):
    if n_examples > 0:
        dataset = dataset.select(range(n_examples))

    # Warmup
    _ = asyncio.run(
        process_dataset(model, dataset.select(range(1)), max_concurrent_reqs))

    start = time.perf_counter()
    results = asyncio.run(process_dataset(model, dataset, max_concurrent_reqs))
    end = time.perf_counter()
    total_time = end - start
    print(f"Total Test Time: {total_time:.4f} seconds")
    if print_metrics:
        print_performance_metrics(results, total_time)
    # Compute WER
    predictions = [res[2] for res in results]
    references = [res[3] for res in results]
    wer = load("wer")
    wer_score = 100 * wer.compute(references=references,
                                  predictions=predictions)
    print("WER:", wer_score)
    return wer_score


if __name__ == "__main__":
    args = ArgumentParser()
    # Alternatives: "openai/whisper-large-v2", "openai/whisper-large-v3-turbo", ...
    args.add_argument("-m",
                      "--model-name",
                      type=str,
                      help="Name of the ASR model to evaluate.",
                      default="openai/whisper-large-v3")
    args.add_argument("-dr",
                      "--dataset-repo",
                      type=str,
                      help="Path/repo of the hf asr dataset to test on.")
    args.add_argument("-dn",
                      "--dataset-name",
                      type=str,
                      help="Name of the hf asr dataset to test on.")
    args.add_argument("--n-examples",
                      type=int,
                      help="Limit the number of examples to evaluate on.",
                      default=-1)
    args.add_argument(
        "--max-concurrent-request",
        type=int,
        help="Limit the number of requests sent to the server at the same time"
    )
    args.add_argument("--expected-wer",
                      type=float,
                      help="Expected WER to compare against.")
    args.add_argument(
        "--extra",
        nargs="*",
        help="Extra keyword arguments (key=value pairs) to be passed "
        "to hf `load_dataset`")
    args = args.parse_args()

    extra_kwargs = {}
    if args.extra:
        for item in args.extra:
            key, value = item.split("=", 1)
            extra_kwargs[key] = value

    print("Running evaluation with args", vars(args))
    dataset = load_hf_dataset(args.dataset_repo, args.dataset_name,
                              **extra_kwargs)

    if not args.max_concurrent_request:
        # No max concurrency: issue all requests at once.
        args.max_concurrent_request = args.n_examples if args.n_examples > 0 \
            else len(dataset)

    wer = run_evaluation(args.model_name, dataset, args.n_examples,
                         args.max_concurrent_request)
    if args.expected_wer:
        torch.testing.assert_close(wer,
                                   args.expected_wer,
                                   atol=1e-1,
                                   rtol=1e-2)
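
For reference, CI invokes this file through its CLI (see `run-tests.sh` above); the same flow can also be driven programmatically with the helpers defined here. A minimal sketch, assuming the functions above are in scope (same module or a REPL) and a vLLM Whisper server is already listening on `localhost:8000`; the expected WER is the one used in `run-tests.sh`:

```python
# Sketch: programmatic equivalent of
#   python test_transcription_api_correctness.py \
#       -m openai/whisper-large-v3 \
#       -dr D4nt3/esb-datasets-earnings22-validation-tiny-filtered \
#       --expected-wer 12.744980
dataset = load_hf_dataset(
    "D4nt3/esb-datasets-earnings22-validation-tiny-filtered", None)
wer = run_evaluation("openai/whisper-large-v3",
                     dataset,
                     max_concurrent_reqs=len(dataset))
torch.testing.assert_close(wer, 12.744980, atol=1e-1, rtol=1e-2)
```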
10 changes: 10 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -328,6 +328,16 @@ steps:
  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
  - bash ./run-tests.sh -c configs/models-small.txt -t 1

- label: Transcription API correctness
  working_dir: "/vllm-workspace/.buildkite/asr-eval"
  source_file_dependencies:
  - csrc/
  - vllm/entrypoints/openai/serving_transcription.py
  - vllm/model_executor/models/whisper.py
  commands:
  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
  - bash ./run-tests.sh

- label: Encoder Decoder tests # 5min
  source_file_dependencies:
  - vllm/
11 changes: 11 additions & 0 deletions docs/source/serving/openai_compatible_server.md
@@ -41,6 +41,8 @@ We currently support the following OpenAI APIs:
  - *Note: `parallel_tool_calls` and `user` parameters are ignored.*
- [Embeddings API](#embeddings-api) (`/v1/embeddings`)
  - Only applicable to [embedding models](../models/pooling_models.md) (`--task embed`).
- [Transcriptions API](#transcriptions-api) (`/v1/audio/transcriptions`)
  - Only applicable to Automatic Speech Recognition (ASR) models (OpenAI Whisper) (`--task generate`).

In addition, we have the following custom APIs:

@@ -298,6 +300,15 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are supported:

(tokenizer-api)=

### Transcriptions API

Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription);
you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it.

<!-- TODO: api enforced limits + uploading audios -->

Code example: <gh-file:examples/online_serving/openai_transcription_client.py>

### Tokenizer API

Our Tokenizer API is a simple wrapper over [HuggingFace-style tokenizers](https://huggingface.co/docs/transformers/en/main_classes/tokenizer).
23 changes: 23 additions & 0 deletions examples/online_serving/openai_transcription_client.py
@@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
from openai import OpenAI

from vllm.assets.audio import AudioAsset

mary_had_lamb = AudioAsset('mary_had_lamb').get_asset_path()
winning_call = AudioAsset('winning_call').get_asset_path()

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)
with open(str(mary_had_lamb), "rb") as f:
    transcription = client.audio.transcriptions.create(
        file=f,
        model="openai/whisper-large-v3",
        language="en",
        response_format="text",
        temperature=0.0)
print("transcription result:", transcription)
1 change: 1 addition & 0 deletions requirements-test.in
@@ -19,6 +19,7 @@ pqdm
ray[adag]==2.40.0
sentence-transformers # required for embedding tests
soundfile # required for audio tests
jiwer # required for audio tests
timm # required for internvl test
torch==2.5.1
torchaudio==2.5.1
86 changes: 86 additions & 0 deletions tests/entrypoints/openai/test_transcription_validation.py
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0

# Imports for transcription API tests
import io
import json

import librosa
import numpy as np
import openai
import pytest
import soundfile as sf

from vllm.assets.audio import AudioAsset

from ...utils import RemoteOpenAIServer


@pytest.fixture
def mary_had_lamb():
    path = AudioAsset('mary_had_lamb').get_asset_path()
    with open(str(path), "rb") as f:
        yield f


@pytest.fixture
def winning_call():
    path = AudioAsset('winning_call').get_asset_path()
    with open(str(path), "rb") as f:
        yield f


@pytest.mark.asyncio
async def test_basic_audio(mary_had_lamb):
    model_name = "openai/whisper-large-v3-turbo"
    server_args = ["--enforce-eager"]
    # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
    prompt = "THE FIRST WORDS I SPOKE"
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        client = remote_server.get_async_client()
        transcription = await client.audio.transcriptions.create(
            model=model_name,
            file=mary_had_lamb,
            language="en",
            response_format="text",
            temperature=0.0)
        out = json.loads(transcription)['text']
        assert "Mary had a little lamb," in out
        # Rewind the fixture so the second request re-reads the audio.
        mary_had_lamb.seek(0)
        # This should "force" whisper to continue prompt in all caps
        transcription_wprompt = await client.audio.transcriptions.create(
            model=model_name,
            file=mary_had_lamb,
            language="en",
            response_format="text",
            prompt=prompt,
            temperature=0.0)
        out_capital = json.loads(transcription_wprompt)['text']
        assert prompt not in out_capital


@pytest.mark.asyncio
async def test_bad_requests(mary_had_lamb):
    model_name = "openai/whisper-small"
    server_args = ["--enforce-eager"]
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        client = remote_server.get_async_client()

        # invalid language
        with pytest.raises(openai.BadRequestError):
            await client.audio.transcriptions.create(model=model_name,
                                                      file=mary_had_lamb,
                                                      language="hh",
                                                      temperature=0.0)

        # Expect audio too long: repeat the timeseries
        mary_had_lamb.seek(0)
        audio, sr = librosa.load(mary_had_lamb)
        repeated_audio = np.tile(audio, 10)
        # Repeated audio to buffer
        buffer = io.BytesIO()
        sf.write(buffer, repeated_audio, sr, format='WAV')
        buffer.seek(0)
        with pytest.raises(openai.BadRequestError):
            await client.audio.transcriptions.create(model=model_name,
                                                      file=buffer,
                                                      language="en",
                                                      temperature=0.0)