CI correctness tests
Signed-off-by: NickLucche <[email protected]>
NickLucche committed Feb 7, 2025
1 parent 9126e9e commit 8dc0e27
Showing 4 changed files with 235 additions and 0 deletions.
16 changes: 16 additions & 0 deletions .buildkite/asr-eval/run-tests.sh
@@ -0,0 +1,16 @@
#!/bin/bash
set -x
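# Extra flags given to this script are forwarded to the vLLM server via "$@"
# (e.g. `bash run-tests.sh --tensor-parallel-size 2` would start the server with TP=2).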

# Start server
python3 -m vllm.entrypoints.openai.api_server --model openai/whisper-large-v3 $@ &
server_pid=$!

# Wait for the server to start; time out after 180 seconds
timeout 180 bash -c 'until curl localhost:8000/v1/models; do sleep 4; done' || exit 1

# NOTE: the expected WER was measured with the equivalent hf.transformers model on the same dataset.
# The original dataset split is about 23GB in size, hence we use a pre-filtered slice.
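# The test asserts that the measured WER matches --expected-wer within a small tolerance
# (torch.testing.assert_close with atol=1e-1, rtol=1e-2).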
python test_transcription_api_correctness.py -m openai/whisper-large-v3 -dr D4nt3/esb-datasets-earnings22-validation-tiny-filtered --expected-wer 12.744980

# Shut down the server
kill $server_pid
208 changes: 208 additions & 0 deletions .buildkite/asr-eval/test_transcription_api_correctness.py
@@ -0,0 +1,208 @@
# SPDX-License-Identifier: Apache-2.0
"""
Evaluate Transcription API correctness by computing Word Error Rate (WER)
on a given ASR dataset. When provided, it will also compare the WER against
a baseline.
"""
import asyncio
import io
import time
from argparse import ArgumentParser
from statistics import mean, median
from typing import List, Optional

import librosa
import soundfile
import torch
from datasets import load_dataset
from evaluate import load
from openai import AsyncOpenAI
from transformers import AutoTokenizer

openai_api_base = "http://localhost:8000/v1"
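# The vLLM OpenAI-compatible server started by run-tests.sh listens here; it does not
# enforce an API key by default, so a placeholder value is passed.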
client = AsyncOpenAI(api_key="EMPTY", base_url=openai_api_base)


def to_bytes(y, sr):
    buffer = io.BytesIO()
    soundfile.write(buffer, y, sr, format="WAV")
    buffer.seek(0)
    return buffer


async def transcribe_audio(client, tokenizer, y, sr):
    status = 200
    try:
        # Send the already-loaded audio directly instead of reading it from disk,
        # but don't count the encoding time in the latency measurement.
        with to_bytes(y, sr) as f:
            start_time = time.perf_counter()
            transcription = await client.audio.transcriptions.create(
                file=f,
                model=tokenizer.name_or_path,
                language="en",
                temperature=0.0,
            )
            end_time = time.perf_counter()
            # NOTE: there's no streaming for transcriptions, so TTFT can't be measured.
    except Exception as e:
        print(f"Error: {e}")
        status = 500
    # Hard check that the server is working properly.
    assert status == 200
    latency = end_time - start_time
    num_output_tokens = len(
        tokenizer(transcription.text, add_special_tokens=False).input_ids)
    return latency, num_output_tokens, transcription.text


async def bound_transcribe(model_name, sem, client, audio, reference):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Use semaphore to limit concurrent requests.
    async with sem:
        result = await transcribe_audio(client, tokenizer, *audio)
    # Normalize *english* output/reference for evaluation.
    out = tokenizer.normalize(result[2])
    ref = tokenizer.normalize(reference)
    return result[:2] + (out, ref)


async def process_dataset(model, data, concurrent_request):
    sem = asyncio.Semaphore(concurrent_request)
    tasks: List[asyncio.Task] = []
    for sample in data:
        audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
        task = asyncio.create_task(
            bound_transcribe(model, sem, client, (audio, sr), sample["text"]))
        tasks.append(task)
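    # asyncio.gather returns results in submission order, so each transcription
    # stays paired with its reference text.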
    return await asyncio.gather(*tasks)


def print_performance_metrics(results, total_time):
    latencies = [res[0] for res in results]
    total_tokens = sum([res[1] for res in results])

    total = len(results)
    print(f"Total Requests: {total}")
    print(f"Successful Requests: {len(latencies)}")
    print(f"Average Latency: {mean(latencies):.4f} seconds")
    print(f"Median Latency: {median(latencies):.4f} seconds")
    perc = sorted(latencies)[int(len(latencies) * 0.95) - 1]
    print(f"95th Percentile Latency: {perc:.4f} seconds")
    # Throughput
    req_throughput = len(latencies) / total_time
    print(f"Estimated req_Throughput: {req_throughput:.2f} requests/s")
    throughput = total_tokens / total_time
    print(f"Estimated Throughput: {throughput:.2f} tok/s")


def add_duration(sample):
    y, sr = sample['audio']["array"], sample['audio']["sampling_rate"]
    sample['duration_ms'] = librosa.get_duration(y=y, sr=sr) * 1000
    return sample


def load_hf_dataset(dataset_repo: str,
                    dataset_name: str,
                    split='validation',
                    **hf_kwargs):
    ## Load and filter the dataset
    dataset = load_dataset(dataset_repo,
                           dataset_name,
                           split=split,
                           **hf_kwargs)
    if 'duration_ms' not in dataset[0]:
        # compute duration to filter
        dataset = dataset.map(add_duration)

    # Whisper max supported duration is 30s; drop longer samples.
    dataset = dataset.filter(lambda example: example['duration_ms'] < 30000)
    return dataset


def run_evaluation(model: str,
                   dataset,
                   n_examples: int = -1,
                   max_concurrent_reqs: Optional[int] = None,
                   print_metrics: bool = True):
    if n_examples > 0:
        dataset = dataset.select(range(n_examples))

    # Warmup
    _ = asyncio.run(
        process_dataset(model, dataset.select(range(1)), max_concurrent_reqs))

    start = time.perf_counter()
    results = asyncio.run(process_dataset(model, dataset, max_concurrent_reqs))
    end = time.perf_counter()
    total_time = end - start
    print(f"Total Test Time: {total_time:.4f} seconds")
    if print_metrics:
        print_performance_metrics(results, total_time)
    # Compute WER
    predictions = [res[2] for res in results]
    references = [res[3] for res in results]
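    # WER metric from HF `evaluate`; it is backed by jiwer (added to requirements-test.in).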
    wer = load("wer")
    wer_score = 100 * wer.compute(references=references,
                                  predictions=predictions)
    print("WER:", wer_score)
    return wer_score


if __name__ == "__main__":
    args = ArgumentParser()
    # Alternatives: "openai/whisper-large-v2", "openai/whisper-large-v3-turbo", ...
    args.add_argument("-m",
                      "--model-name",
                      type=str,
                      help="Name of the ASR model to evaluate.",
                      default="openai/whisper-large-v3")
    args.add_argument("-dr",
                      "--dataset-repo",
                      type=str,
                      help="Path/repo of the HF ASR dataset to test on.")
    args.add_argument("-dn",
                      "--dataset-name",
                      type=str,
                      help="Name of the HF ASR dataset to test on.")
    args.add_argument("--n-examples",
                      type=int,
                      help="Limit the number of examples to evaluate on.",
                      default=-1)
    args.add_argument(
        "--max-concurrent-request",
        type=int,
        help="Limit the number of requests sent to the server at the same time."
    )
    args.add_argument("--expected-wer",
                      type=float,
                      help="Expected WER to compare against.")
    args.add_argument(
        "--extra",
        nargs="*",
        help="Extra keyword arguments (key=value pairs) to be passed "
        "to hf `load_dataset`")
    args = args.parse_args()

    extra_kwargs = {}
    if args.extra:
        for item in args.extra:
            key, value = item.split("=", 1)
            extra_kwargs[key] = value

    print("Running evaluation with args", vars(args))
    dataset = load_hf_dataset(args.dataset_repo, args.dataset_name,
                              **extra_kwargs)

    if not args.max_concurrent_request:
        # No max concurrency: send one request per evaluated example.
        args.max_concurrent_request = args.n_examples if args.n_examples > 0 \
            else len(dataset)

    wer = run_evaluation(args.model_name, dataset, args.n_examples,
                         args.max_concurrent_request)
    if args.expected_wer:
        torch.testing.assert_close(wer,
                                   args.expected_wer,
                                   atol=1e-1,
                                   rtol=1e-2)
10 changes: 10 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -326,6 +326,16 @@ steps:
  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
  - bash ./run-tests.sh -c configs/models-small.txt -t 1

- label: Transcription API correctness
  working_dir: "/vllm-workspace/.buildkite/asr-eval"
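  # Run when transcription serving, the Whisper model, or kernel sources (csrc/) change: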
  source_file_dependencies:
  - csrc/
  - vllm/entrypoints/openai/serving_transcription.py
  - vllm/model_executor/models/whisper.py
  commands:
  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
  - bash ./run-tests.sh

- label: Encoder Decoder tests # 5min
  source_file_dependencies:
  - vllm/
1 change: 1 addition & 0 deletions requirements-test.in
@@ -19,6 +19,7 @@ pqdm
ray[adag]==2.40.0
sentence-transformers # required for embedding tests
soundfile # required for audio tests
jiwer # required for audio tests
timm # required for internvl test
torch==2.5.1
torchaudio==2.5.1
