From 8dc0e277cb9855adbdc3f600480f2962dd5f1e9d Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Fri, 7 Feb 2025 15:58:22 +0000
Subject: [PATCH] CI correctness tests

Signed-off-by: NickLucche
---
 .buildkite/asr-eval/run-tests.sh              |  16 ++
 .../test_transcription_api_correctness.py     | 208 ++++++++++++++++++
 .buildkite/test-pipeline.yaml                 |  10 +
 requirements-test.in                          |   1 +
 4 files changed, 235 insertions(+)
 create mode 100644 .buildkite/asr-eval/run-tests.sh
 create mode 100644 .buildkite/asr-eval/test_transcription_api_correctness.py

diff --git a/.buildkite/asr-eval/run-tests.sh b/.buildkite/asr-eval/run-tests.sh
new file mode 100644
index 0000000000000..9cbaf9f7c3fdd
--- /dev/null
+++ b/.buildkite/asr-eval/run-tests.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+set -x
+
+# Start server
+python3 -m vllm.entrypoints.openai.api_server --model openai/whisper-large-v3 "$@" &
+server_pid=$!
+
+# Wait for server to start, timeout after 180 seconds
+timeout 180 bash -c 'until curl localhost:8000/v1/models; do sleep 4; done' || exit 1
+
+# NOTE: Expected WER measured with the equivalent hf.transformers model on the same dataset.
+# Original dataset split is about 23GB in size, hence we use a pre-filtered slice.
+python test_transcription_api_correctness.py -m openai/whisper-large-v3 -dr D4nt3/esb-datasets-earnings22-validation-tiny-filtered --expected-wer 12.744980
+
+# Gracefully stop the server
+kill $server_pid
diff --git a/.buildkite/asr-eval/test_transcription_api_correctness.py b/.buildkite/asr-eval/test_transcription_api_correctness.py
new file mode 100644
index 0000000000000..f3fb98527e3d0
--- /dev/null
+++ b/.buildkite/asr-eval/test_transcription_api_correctness.py
@@ -0,0 +1,208 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Evaluate Transcription API correctness by computing the Word Error Rate (WER)
+on a given ASR dataset. When a baseline WER is provided, the measured WER is
+compared against it.
+"""
+import asyncio
+import io
+import time
+from argparse import ArgumentParser
+from statistics import mean, median
+from typing import List, Optional
+
+import librosa
+import soundfile
+import torch
+from datasets import load_dataset
+from evaluate import load
+from openai import AsyncOpenAI
+from transformers import AutoTokenizer
+
+openai_api_base = "http://localhost:8000/v1"
+client = AsyncOpenAI(api_key="EMPTY", base_url=openai_api_base)
+
+
+def to_bytes(y, sr):
+    buffer = io.BytesIO()
+    soundfile.write(buffer, y, sr, format="WAV")
+    buffer.seek(0)
+    return buffer
+
+
+async def transcribe_audio(client, tokenizer, y, sr):
+    status = 200
+    try:
+        # Send the already-loaded audio directly instead of reading from disk,
+        # but don't include that time in the latency measurement
+        with to_bytes(y, sr) as f:
+            start_time = time.perf_counter()
+            transcription = await client.audio.transcriptions.create(
+                file=f,
+                model=tokenizer.name_or_path,
+                language="en",
+                temperature=0.0,
+            )
+            end_time = time.perf_counter()
+            # NOTE: transcriptions are not streamed, so TTFT cannot be measured
+    except Exception as e:
+        print(f"Error: {e}")
+        status = 500
+    # Fail hard if the server did not respond successfully
+    assert status == 200
+    latency = end_time - start_time
+    num_output_tokens = len(
+        tokenizer(transcription.text, add_special_tokens=False).input_ids)
+    return latency, num_output_tokens, transcription.text
+
+
+async def bound_transcribe(model_name, sem, client, audio, reference):
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    # Use semaphore to limit concurrent requests.
+    async with sem:
+        result = await transcribe_audio(client, tokenizer, *audio)
+    # Normalize *English* output/reference for evaluation.
+    out = tokenizer.normalize(result[2])
+    ref = tokenizer.normalize(reference)
+    return result[:2] + (out, ref)
+
+
+async def process_dataset(model, data, concurrent_request):
+    sem = asyncio.Semaphore(concurrent_request)
+    tasks: List[asyncio.Task] = []
+    for sample in data:
+        audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
+        task = asyncio.create_task(
+            bound_transcribe(model, sem, client, (audio, sr), sample["text"]))
+        tasks.append(task)
+    return await asyncio.gather(*tasks)
+
+
+def print_performance_metrics(results, total_time):
+    latencies = [res[0] for res in results]
+    total_tokens = sum([res[1] for res in results])
+
+    total = len(results)
+    print(f"Total Requests: {total}")
+    print(f"Successful Requests: {len(latencies)}")
+    print(f"Average Latency: {mean(latencies):.4f} seconds")
+    print(f"Median Latency: {median(latencies):.4f} seconds")
+    perc = sorted(latencies)[int(len(latencies) * 0.95) - 1]
+    print(f"95th Percentile Latency: {perc:.4f} seconds")
+    # Throughput
+    req_throughput = len(latencies) / total_time
+    print(f"Estimated Request Throughput: {req_throughput:.2f} requests/s")
+    throughput = total_tokens / total_time
+    print(f"Estimated Token Throughput: {throughput:.2f} tok/s")
+
+
+def add_duration(sample):
+    y, sr = sample['audio']["array"], sample['audio']["sampling_rate"]
+    sample['duration_ms'] = librosa.get_duration(y=y, sr=sr) * 1000
+    return sample
+
+
+def load_hf_dataset(dataset_repo: str,
+                    dataset_name: str,
+                    split='validation',
+                    **hf_kwargs):
+    # Load and filter the dataset
+    dataset = load_dataset(dataset_repo,
+                           dataset_name,
+                           split=split,
+                           **hf_kwargs)
+    if 'duration_ms' not in dataset[0]:
+        # Compute durations so that over-long samples can be filtered below
+        dataset = dataset.map(add_duration)
+
+    # Keep samples within Whisper's max supported duration (30s)
+    dataset = dataset.filter(lambda example: example['duration_ms'] < 30000)
+    return dataset
+
+
+def run_evaluation(model: str,
+                   dataset,
+                   n_examples: int = -1,
+                   max_concurrent_reqs: Optional[int] = None,
+                   print_metrics: bool = True):
+    if n_examples > 0:
+        dataset = dataset.select(range(n_examples))
+
+    # Warmup with a single sample
+    _ = asyncio.run(
+        process_dataset(model, dataset.select(range(1)), max_concurrent_reqs))
+
+    start = time.perf_counter()
+    results = asyncio.run(process_dataset(model, dataset, max_concurrent_reqs))
+    end = time.perf_counter()
+    total_time = end - start
+    print(f"Total Test Time: {total_time:.4f} seconds")
+    if print_metrics:
+        print_performance_metrics(results, total_time)
+    # Compute WER
+    predictions = [res[2] for res in results]
+    references = [res[3] for res in results]
+    wer = load("wer")
+    wer_score = 100 * wer.compute(references=references,
+                                  predictions=predictions)
+    print("WER:", wer_score)
+    return wer_score
+
+
+if __name__ == "__main__":
+    args = ArgumentParser()
+    # Alternatives: "openai/whisper-large-v2", "openai/whisper-large-v3-turbo", ...
+    args.add_argument("-m",
+                      "--model-name",
+                      type=str,
+                      help="Name of the ASR model to evaluate.",
+                      default="openai/whisper-large-v3")
+    args.add_argument("-dr",
+                      "--dataset-repo",
+                      type=str,
+                      help="Path/repo of the HF ASR dataset to test on.")
+    args.add_argument("-dn",
+                      "--dataset-name",
+                      type=str,
+                      help="Name of the HF ASR dataset to test on.")
+    args.add_argument("--n-examples",
+                      type=int,
+                      help="Limit the number of examples to evaluate on.",
+                      default=-1)
+    args.add_argument(
+        "--max-concurrent-request",
+        type=int,
+        help="Limit the number of concurrent requests sent to the server.")
+    args.add_argument("--expected-wer",
+                      type=float,
+                      help="Expected WER to compare against.")
+    args.add_argument(
+        "--extra",
+        nargs="*",
+        help="Extra keyword arguments (key=value pairs) to be passed "
+        "to HF `load_dataset`.")
+    args = args.parse_args()
+
+    extra_kwargs = {}
+    if args.extra:
+        for item in args.extra:
+            key, value = item.split("=", 1)
+            extra_kwargs[key] = value
+
+    print("Running evaluation with args", vars(args))
+    dataset = load_hf_dataset(args.dataset_repo, args.dataset_name,
+                              **extra_kwargs)
+
+    if not args.max_concurrent_request:
+        # No explicit cap: allow all requests to be in flight at once
+        args.max_concurrent_request = args.n_examples if args.n_examples > 0\
+            else len(dataset)
+
+    wer = run_evaluation(args.model_name, dataset, args.n_examples,
+                         args.max_concurrent_request)
+    if args.expected_wer is not None:
+        torch.testing.assert_close(wer,
+                                   args.expected_wer,
+                                   atol=1e-1,
+                                   rtol=1e-2)
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index d5d02fdeb7f4b..f11b3cbb646ca 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -326,6 +326,16 @@ steps:
   - export VLLM_WORKER_MULTIPROC_METHOD=spawn
   - bash ./run-tests.sh -c configs/models-small.txt -t 1
 
+- label: Transcription API correctness
+  working_dir: "/vllm-workspace/.buildkite/asr-eval"
+  source_file_dependencies:
+  - csrc/
+  - vllm/entrypoints/openai/serving_transcription.py
+  - vllm/model_executor/models/whisper.py
+  commands:
+  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+  - bash ./run-tests.sh
+
 - label: Encoder Decoder tests # 5min
   source_file_dependencies:
   - vllm/
diff --git a/requirements-test.in b/requirements-test.in
index 229d743ec802b..ecf874ecc50fe 100644
--- a/requirements-test.in
+++ b/requirements-test.in
@@ -19,6 +19,7 @@ pqdm
 ray[adag]==2.40.0
 sentence-transformers # required for embedding tests
 soundfile # required for audio tests
+jiwer # required for transcription correctness (WER) tests
 timm # required for internvl test
 torch==2.5.1
 torchaudio==2.5.1
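
For reference, the pass/fail gate added above is torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2), which accepts |actual - expected| <= atol + rtol * |expected|; with the 12.744980 baseline from run-tests.sh that is a band of roughly +/-0.23 WER points. The snippet below is a minimal standalone sketch of that tolerance (illustrative only, not part of the patch; the example "measured" values are made up):

# Illustrative only: reproduces the tolerance applied by the final
# assert_close gate, using the baseline WER from run-tests.sh.
import torch

expected_wer = 12.744980
atol, rtol = 1e-1, 1e-2
# assert_close accepts |actual - expected| <= atol + rtol * |expected|
band = atol + rtol * abs(expected_wer)  # ~0.227 WER points

for measured in (12.80, 13.10):  # hypothetical measured WER values
    try:
        torch.testing.assert_close(measured, expected_wer, atol=atol, rtol=rtol)
        print(f"WER {measured:.2f}: within +/-{band:.3f} of baseline -> pass")
    except AssertionError:
        print(f"WER {measured:.2f}: outside +/-{band:.3f} of baseline -> fail")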