Allow any vLLM engine args as env vars, refactor
alpayariyak committed Jul 2, 2024
1 parent 0e1e383 commit a08d83f
Showing 3 changed files with 66 additions and 74 deletions.
62 changes: 0 additions & 62 deletions src/config.py

This file was deleted.

20 changes: 8 additions & 12 deletions src/engine.py
@@ -7,21 +7,21 @@
from typing import AsyncGenerator
import time

-from vllm import AsyncLLMEngine, AsyncEngineArgs
+from vllm import AsyncLLMEngine
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest, ErrorResponse

from utils import DummyRequest, JobInput, BatchSize, create_error_response
from constants import DEFAULT_MAX_CONCURRENCY, DEFAULT_BATCH_SIZE, DEFAULT_BATCH_SIZE_GROWTH_FACTOR, DEFAULT_MIN_BATCH_SIZE
from tokenizer import TokenizerWrapper
-from config import EngineConfig
+from engine_args import get_engine_args

class vLLMEngine:
    def __init__(self, engine = None):
        load_dotenv() # For local development
-        self.config = EngineConfig().config
-        self.tokenizer = TokenizerWrapper(self.config.get("tokenizer"), self.config.get("tokenizer_revision"), self.config.get("trust_remote_code"))
+        self.engine_args = get_engine_args()
+        self.tokenizer = TokenizerWrapper(self.engine_args.tokenizer, self.engine_args.tokenizer_revision, self.engine_args.trust_remote_code)
        self.llm = self._initialize_llm() if engine is None else engine
        self.max_concurrency = int(os.getenv("MAX_CONCURRENCY", DEFAULT_MAX_CONCURRENCY))
        self.default_batch_size = int(os.getenv("DEFAULT_BATCH_SIZE", DEFAULT_BATCH_SIZE))
@@ -102,7 +102,7 @@ async def _generate_vllm(self, llm_input, validated_sampling_params, batch_size,
    def _initialize_llm(self):
        try:
            start = time.time()
-            engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**self.config))
+            engine = AsyncLLMEngine.from_engine_args(self.engine_args)
            end = time.time()
            logging.info(f"Initialized vLLM engine in {end - start:.2f}s")
            return engine
@@ -111,15 +111,11 @@ def _initialize_llm(self):
            raise e


-class OpenAIvLLMEngine:
+class OpenAIvLLMEngine(vLLMEngine):
    def __init__(self, vllm_engine):
-        self.config = vllm_engine.config
-        self.llm = vllm_engine.llm
-        self.served_model_name = os.getenv("OPENAI_SERVED_MODEL_NAME_OVERRIDE") or self.config["model"]
+        super().__init__(vllm_engine)
+        self.served_model_name = os.getenv("OPENAI_SERVED_MODEL_NAME_OVERRIDE") or self.engine_args.model
        self.response_role = os.getenv("OPENAI_RESPONSE_ROLE") or "assistant"
-        self.tokenizer = vllm_engine.tokenizer
-        self.default_batch_size = vllm_engine.default_batch_size
-        self.batch_size_growth_factor, self.min_batch_size = vllm_engine.batch_size_growth_factor, vllm_engine.min_batch_size
        self._initialize_engines()
        self.raw_openai_output = bool(int(os.getenv("RAW_OPENAI_OUTPUT", 1)))

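Since OpenAIvLLMEngine now inherits from vLLMEngine, the OpenAI-compatible wrapper is built around an existing base engine instead of copying its fields one by one. A minimal wiring sketch under that assumption (the handler that performs this wiring is not part of this diff, so the variable names below are illustrative):

# Illustrative wiring, not taken from this commit:
from engine import vLLMEngine, OpenAIvLLMEngine

vllm_engine = vLLMEngine()                      # engine args are resolved by get_engine_args()
openai_engine = OpenAIvLLMEngine(vllm_engine)   # OpenAI-compatible wrapper around the same engine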
58 changes: 58 additions & 0 deletions src/engine_args.py
@@ -0,0 +1,58 @@
import os
import json
import logging
from torch.cuda import device_count
from vllm import AsyncEngineArgs

env_to_args_map = {
"MODEL_NAME": "model",
"MODEL_REVISION": "revision",
"TOKENIZER_NAME": "tokenizer",
"TOKENIZER_REVISION": "tokenizer_revision",
"QUANTIZATION": "quantization"
}

def get_local_args():
    # Engine args for a model baked into the image; empty if nothing was baked in.
    local_args = {}
    if os.path.exists("/local_metadata.json"):
        with open("/local_metadata.json", "r") as f:
            local_metadata = json.load(f)
        if local_metadata.get("model_name") is None:
            raise ValueError("Model name is not found in /local_metadata.json, there was a problem when baking the model in.")
        local_args = {env_to_args_map[k.upper()]: v for k, v in local_metadata.items() if k.upper() in env_to_args_map}
        # The model is baked in, so run fully offline.
        os.environ["TRANSFORMERS_OFFLINE"] = "1"
        os.environ["HF_HUB_OFFLINE"] = "1"
    return local_args

def get_engine_args():
    # Start with default args
    args = {
        "disable_log_stats": True,
        "disable_log_requests": True,
        "gpu_memory_utilization": float(os.getenv("GPU_MEMORY_UTILIZATION", 0.9)),
    }

    # Get env args that match keys in AsyncEngineArgs
    env_args = {k.lower(): v for k, v in dict(os.environ).items() if k.lower() in AsyncEngineArgs.__dataclass_fields__}
    args.update(env_args)

    # Get local args if model is baked in and overwrite env args
    local_args = get_local_args()
    args.update(local_args)

    # Set tensor parallel size and max parallel loading workers if more than 1 GPU is available
    num_gpus = device_count()
    if num_gpus > 1:
        args["tensor_parallel_size"] = num_gpus
        args["max_parallel_loading_workers"] = None
        if os.getenv("MAX_PARALLEL_LOADING_WORKERS"):
            logging.warning("Overriding MAX_PARALLEL_LOADING_WORKERS with None because more than 1 GPU is available.")

    # Deprecated env args backwards compatibility
if args["kv_cache_dtype"] == "fp8_e5m2":
args["kv_cache_dtype"] = "fp8"
logging.warning("Using fp8_e5m2 is deprecated. Please use fp8 instead.")
if os.getenv("MAX_CONTEXT_LEN_TO_CAPTURE"):
args["max_seq_len_to_capture"] = int(os.getenv("MAX_CONTEXT_LEN_TO_CAPTURE"))
logging.warning("Using MAX_CONTEXT_LEN_TO_CAPTURE is deprecated. Please use MAX_SEQ_LEN_TO_CAPTURE instead.")
return AsyncEngineArgs(**args)
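The net effect is that any environment variable whose lower-cased name matches an AsyncEngineArgs field is forwarded to the engine, on top of the explicit defaults and deprecation shims above; values arrive as the strings os.environ provides. A rough usage sketch, assuming the variables are set before get_engine_args() runs (the model name and values below are placeholders, not part of the commit):

# Illustrative only; placeholder values, not from the commit.
import os
from engine_args import get_engine_args

os.environ["MODEL"] = "facebook/opt-125m"    # "model" is an AsyncEngineArgs field
os.environ["DTYPE"] = "half"                 # "dtype" is an AsyncEngineArgs field
os.environ["KV_CACHE_DTYPE"] = "fp8_e5m2"    # rewritten to "fp8" by the deprecation shim

engine_args = get_engine_args()
print(engine_args.model, engine_args.dtype, engine_args.kv_cache_dtype)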
