diff --git a/Dockerfile b/Dockerfile
index 089b4de..0029433 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -12,7 +12,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
     python3 -m pip install --upgrade -r /requirements.txt
 
 # Install vLLM (switching back to pip installs since issues that required building fork are fixed and space optimization is not as important since caching) and FlashInfer
-RUN python3 -m pip install vllm==0.5.3.post1 && \
+RUN python3 -m pip install vllm==0.5.4 && \
     python3 -m pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3
 
 # Setup for Option 2: Building the Image with the Model included
diff --git a/builder/requirements.txt b/builder/requirements.txt
index db07106..caf0e03 100644
--- a/builder/requirements.txt
+++ b/builder/requirements.txt
@@ -7,4 +7,5 @@ packaging
 typing-extensions==4.7.1
 pydantic
 pydantic-settings
-hf-transfer
\ No newline at end of file
+hf-transfer
+transformers==4.43.3
\ No newline at end of file
diff --git a/src/engine.py b/src/engine.py
index 6962027..569c179 100644
--- a/src/engine.py
+++ b/src/engine.py
@@ -126,7 +126,7 @@ async def _initialize_engines(self):
         self.model_config = await self.llm.get_model_config()
 
         self.chat_engine = OpenAIServingChat(
-            engine=self.llm,
+            async_engine_client=self.llm,
             model_config=self.model_config,
             served_model_names=[self.served_model_name],
             response_role=self.response_role,
@@ -136,7 +136,7 @@ async def _initialize_engines(self):
             request_logger=None
         )
         self.completion_engine = OpenAIServingCompletion(
-            engine=self.llm,
+            async_engine_client=self.llm,
             model_config=self.model_config,
             served_model_names=[self.served_model_name],
             lora_modules=[],
diff --git a/src/engine_args.py b/src/engine_args.py
index 42c7d6d..d9e9a9b 100644
--- a/src/engine_args.py
+++ b/src/engine_args.py
@@ -162,8 +162,4 @@ def get_engine_args():
         args["max_seq_len_to_capture"] = int(os.getenv("MAX_CONTEXT_LEN_TO_CAPTURE"))
         logging.warning("Using MAX_CONTEXT_LEN_TO_CAPTURE is deprecated. Please use MAX_SEQ_LEN_TO_CAPTURE instead.")
 
-    if "gemma-2" in args.get("model", "").lower():
-        os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
-        logging.info("Using FLASHINFER for gemma-2 model.")
-
     return AsyncEngineArgs(**args)
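
Note: the src/engine.py hunks above track a keyword rename in vLLM 0.5.4, where the OpenAI-compatible serving classes take the engine handle as async_engine_client instead of engine. The following is a minimal standalone sketch of the new call, not this worker's code: the model name and the values of the remaining keyword arguments are placeholders, and the exact keyword set should be checked against the installed vLLM version.

    # Sketch only, assuming vllm==0.5.4; placeholder model and illustrative
    # keyword values, not the worker's real configuration.
    import asyncio

    from vllm import AsyncEngineArgs, AsyncLLMEngine
    from vllm.entrypoints.openai.serving_chat import OpenAIServingChat


    async def build_chat_engine() -> OpenAIServingChat:
        # Spin up the async engine with a small placeholder model.
        llm = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model="facebook/opt-125m")
        )
        model_config = await llm.get_model_config()
        return OpenAIServingChat(
            async_engine_client=llm,  # was engine=llm before vllm==0.5.4
            model_config=model_config,
            served_model_names=["opt-125m"],
            response_role="assistant",
            lora_modules=None,
            prompt_adapters=None,
            request_logger=None,
            chat_template=None,
        )


    chat_engine = asyncio.run(build_chat_engine())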