diff --git a/src/config.py b/src/config.py
deleted file mode 100644
index 67b9836..0000000
--- a/src/config.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import os
-import json
-import logging
-from dotenv import load_dotenv
-from torch.cuda import device_count
-from utils import get_int_bool_env
-
-class EngineConfig:
-    def __init__(self):
-        load_dotenv()
-        self.hf_home = os.getenv("HF_HOME")
-        # Check if /local_metadata.json exists
-        local_metadata = {}
-        if os.path.exists("/local_metadata.json"):
-            with open("/local_metadata.json", "r") as f:
-                local_metadata = json.load(f)
-            if local_metadata.get("model_name") is None:
-                raise ValueError("Model name is not found in /local_metadata.json, there was a problem when you baked the model in.")
-            logging.info("Using baked-in model")
-            os.environ["TRANSFORMERS_OFFLINE"] = "1"
-            os.environ["HF_HUB_OFFLINE"] = "1"
-
-        self.model_name_or_path = local_metadata.get("model_name", os.getenv("MODEL_NAME"))
-        self.model_revision = local_metadata.get("revision", os.getenv("MODEL_REVISION"))
-        self.tokenizer_name_or_path = local_metadata.get("tokenizer_name", os.getenv("TOKENIZER_NAME")) or self.model_name_or_path
-        self.tokenizer_revision = local_metadata.get("tokenizer_revision", os.getenv("TOKENIZER_REVISION"))
-        self.quantization = local_metadata.get("quantization", os.getenv("QUANTIZATION"))
-        self.config = self._initialize_config()
-    def _initialize_config(self):
-        args = {
-            "model": self.model_name_or_path,
-            "revision": self.model_revision,
-            "download_dir": self.hf_home,
-            "quantization": self.quantization,
-            "load_format": os.getenv("LOAD_FORMAT", "auto"),
-            "dtype": os.getenv("DTYPE", "half" if self.quantization else "auto"),
-            "tokenizer": self.tokenizer_name_or_path,
-            "tokenizer_revision": self.tokenizer_revision,
-            "disable_log_stats": get_int_bool_env("DISABLE_LOG_STATS", True),
-            "disable_log_requests": get_int_bool_env("DISABLE_LOG_REQUESTS", True),
-            "trust_remote_code": get_int_bool_env("TRUST_REMOTE_CODE", False),
-            "gpu_memory_utilization": float(os.getenv("GPU_MEMORY_UTILIZATION", 0.95)),
-            "max_parallel_loading_workers": None if device_count() > 1 or not os.getenv("MAX_PARALLEL_LOADING_WORKERS") else int(os.getenv("MAX_PARALLEL_LOADING_WORKERS")),
-            "max_model_len": int(os.getenv("MAX_MODEL_LEN")) if os.getenv("MAX_MODEL_LEN") else None,
-            "tensor_parallel_size": device_count(),
-            "seed": int(os.getenv("SEED")) if os.getenv("SEED") else None,
-            "kv_cache_dtype": os.getenv("KV_CACHE_DTYPE"),
-            "block_size": int(os.getenv("BLOCK_SIZE")) if os.getenv("BLOCK_SIZE") else None,
-            "swap_space": int(os.getenv("SWAP_SPACE")) if os.getenv("SWAP_SPACE") else None,
-            "max_seq_len_to_capture": int(os.getenv("MAX_SEQ_LEN_TO_CAPTURE")) if os.getenv("MAX_SEQ_LEN_TO_CAPTURE") else None,
-            "disable_custom_all_reduce": get_int_bool_env("DISABLE_CUSTOM_ALL_REDUCE", False),
-            "enforce_eager": get_int_bool_env("ENFORCE_EAGER", False)
-        }
-        if args["kv_cache_dtype"] == "fp8_e5m2":
-            args["kv_cache_dtype"] = "fp8"
-            logging.warning("Using fp8_e5m2 is deprecated. Please use fp8 instead.")
-        if os.getenv("MAX_CONTEXT_LEN_TO_CAPTURE"):
-            args["max_seq_len_to_capture"] = int(os.getenv("MAX_CONTEXT_LEN_TO_CAPTURE"))
-            logging.warning("Using MAX_CONTEXT_LEN_TO_CAPTURE is deprecated. Please use MAX_SEQ_LEN_TO_CAPTURE instead.")
-
-
-        return {k: v for k, v in args.items() if v not in [None, ""]}
diff --git a/src/engine.py b/src/engine.py
index 1ac3f8d..e8fc606 100644
--- a/src/engine.py
+++ b/src/engine.py
@@ -7,7 +7,7 @@
 from typing import AsyncGenerator
 import time
 
-from vllm import AsyncLLMEngine, AsyncEngineArgs
+from vllm import AsyncLLMEngine
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest, ErrorResponse
@@ -15,13 +15,13 @@
 from utils import DummyRequest, JobInput, BatchSize, create_error_response
 from constants import DEFAULT_MAX_CONCURRENCY, DEFAULT_BATCH_SIZE, DEFAULT_BATCH_SIZE_GROWTH_FACTOR, DEFAULT_MIN_BATCH_SIZE
 from tokenizer import TokenizerWrapper
-from config import EngineConfig
+from engine_args import get_engine_args
 
 class vLLMEngine:
     def __init__(self, engine = None):
         load_dotenv() # For local development
-        self.config = EngineConfig().config
-        self.tokenizer = TokenizerWrapper(self.config.get("tokenizer"), self.config.get("tokenizer_revision"), self.config.get("trust_remote_code"))
+        self.engine_args = get_engine_args()
+        self.tokenizer = TokenizerWrapper(self.engine_args.tokenizer, self.engine_args.tokenizer_revision, self.engine_args.trust_remote_code)
         self.llm = self._initialize_llm() if engine is None else engine
         self.max_concurrency = int(os.getenv("MAX_CONCURRENCY", DEFAULT_MAX_CONCURRENCY))
         self.default_batch_size = int(os.getenv("DEFAULT_BATCH_SIZE", DEFAULT_BATCH_SIZE))
@@ -102,7 +102,7 @@ async def _generate_vllm(self, llm_input, validated_sampling_params, batch_size,
     def _initialize_llm(self):
         try:
             start = time.time()
-            engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**self.config))
+            engine = AsyncLLMEngine.from_engine_args(self.engine_args)
             end = time.time()
             logging.info(f"Initialized vLLM engine in {end - start:.2f}s")
             return engine
@@ -111,15 +111,11 @@ def _initialize_llm(self):
             raise e
 
 
-class OpenAIvLLMEngine:
+class OpenAIvLLMEngine(vLLMEngine):
     def __init__(self, vllm_engine):
-        self.config = vllm_engine.config
-        self.llm = vllm_engine.llm
-        self.served_model_name = os.getenv("OPENAI_SERVED_MODEL_NAME_OVERRIDE") or self.config["model"]
+        super().__init__(vllm_engine.llm)
+        self.served_model_name = os.getenv("OPENAI_SERVED_MODEL_NAME_OVERRIDE") or self.engine_args.model
         self.response_role = os.getenv("OPENAI_RESPONSE_ROLE") or "assistant"
-        self.tokenizer = vllm_engine.tokenizer
-        self.default_batch_size = vllm_engine.default_batch_size
-        self.batch_size_growth_factor, self.min_batch_size = vllm_engine.batch_size_growth_factor, vllm_engine.min_batch_size
         self._initialize_engines()
         self.raw_openai_output = bool(int(os.getenv("RAW_OPENAI_OUTPUT", 1)))
 
diff --git a/src/engine_args.py b/src/engine_args.py
new file mode 100644
index 0000000..ed4e6e3
--- /dev/null
+++ b/src/engine_args.py
@@ -0,0 +1,58 @@
+import os
+import json
+import logging
+from torch.cuda import device_count
+from vllm import AsyncEngineArgs
+
+env_to_args_map = {
+    "MODEL_NAME": "model",
+    "MODEL_REVISION": "revision",
+    "TOKENIZER_NAME": "tokenizer",
+    "TOKENIZER_REVISION": "tokenizer_revision",
+    "QUANTIZATION": "quantization"
+}
+
+def get_local_args():
+    if os.path.exists("/local_metadata.json"):
+        with open("/local_metadata.json", "r") as f:
+            local_metadata = json.load(f)
+        if local_metadata.get("model_name") is None:
+            raise ValueError("Model name is not found in /local_metadata.json, there was a problem when baking the model in.")
+        local_args = {env_to_args_map[k.upper()]: v for k, v in local_metadata.items() if k.upper() in env_to_args_map}
+        os.environ["TRANSFORMERS_OFFLINE"] = "1"
+        os.environ["HF_HUB_OFFLINE"] = "1"
+        return local_args
+    return {}
+
+def get_engine_args():
+    # Start with default args
+    args = {
+        "disable_log_stats": True,
+        "disable_log_requests": True,
+        "gpu_memory_utilization": float(os.getenv("GPU_MEMORY_UTILIZATION", 0.9)),
+    }
+
+    # Get env args that match keys in AsyncEngineArgs, mapping renamed env vars (e.g. MODEL_NAME -> model) via env_to_args_map
+    env_args = {env_to_args_map.get(k, k.lower()): v for k, v in dict(os.environ).items() if env_to_args_map.get(k, k.lower()) in AsyncEngineArgs.__dataclass_fields__}
+    args.update(env_args)
+
+    # Get local args if model is baked in and overwrite env args
+    local_args = get_local_args()
+    args.update(local_args)
+
+    # Set tensor parallel size and max parallel loading workers if more than 1 GPU is available
+    num_gpus = device_count()
+    if num_gpus > 1:
+        args["tensor_parallel_size"] = num_gpus
+        args["max_parallel_loading_workers"] = None
+        if os.getenv("MAX_PARALLEL_LOADING_WORKERS"):
+            logging.warning("Overriding MAX_PARALLEL_LOADING_WORKERS with None because more than 1 GPU is available.")
+
+    # Deprecated env args backwards compatibility
+    if args.get("kv_cache_dtype") == "fp8_e5m2":
+        args["kv_cache_dtype"] = "fp8"
+        logging.warning("Using fp8_e5m2 is deprecated. Please use fp8 instead.")
+    if os.getenv("MAX_CONTEXT_LEN_TO_CAPTURE"):
+        args["max_seq_len_to_capture"] = int(os.getenv("MAX_CONTEXT_LEN_TO_CAPTURE"))
+        logging.warning("Using MAX_CONTEXT_LEN_TO_CAPTURE is deprecated. Please use MAX_SEQ_LEN_TO_CAPTURE instead.")
+    return AsyncEngineArgs(**args)
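
Note on the environment scan in get_engine_args(): os.environ values are always strings, while several AsyncEngineArgs fields expect ints, floats, or booleans. The sketch below is illustrative only and not part of the diff; the names _ENV_CASTS and engine_args_from_env and the small cast table are assumptions, showing one way the same field-name matching against AsyncEngineArgs.__dataclass_fields__ could coerce values before the dataclass is constructed.

    import os
    from vllm import AsyncEngineArgs

    # Hypothetical cast table covering a few numeric/boolean AsyncEngineArgs fields;
    # anything not listed is passed through as the raw string.
    _ENV_CASTS = {
        "max_model_len": int,
        "tensor_parallel_size": int,
        "gpu_memory_utilization": float,
        "trust_remote_code": lambda v: v.lower() in ("1", "true"),
        "enforce_eager": lambda v: v.lower() in ("1", "true"),
    }

    def engine_args_from_env() -> AsyncEngineArgs:
        """Build AsyncEngineArgs from env vars whose lowercased names match its fields."""
        matched = {
            name.lower(): _ENV_CASTS.get(name.lower(), str)(value)
            for name, value in os.environ.items()
            if name.lower() in AsyncEngineArgs.__dataclass_fields__
        }
        return AsyncEngineArgs(**matched)

    # Example: running with MODEL=facebook/opt-125m MAX_MODEL_LEN=4096 ENFORCE_EAGER=1
    # would yield AsyncEngineArgs(model="facebook/opt-125m", max_model_len=4096, enforce_eager=True, ...)

A hand-written cast table keeps the sketch short; introspecting the dataclass field annotations would be the more general way to do the same coercion.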