diff --git a/tests/v1/shutdown/test_forward_error.py b/tests/v1/shutdown/test_forward_error.py
new file mode 100644
index 0000000000000..5017bc21ac71f
--- /dev/null
+++ b/tests/v1/shutdown/test_forward_error.py
@@ -0,0 +1,121 @@
+"""Test that we handle an Error in model forward and shutdown."""
+
+import asyncio
+
+import pytest
+
+from tests.utils import wait_for_gpu_memory_to_clear
+from vllm import LLM, SamplingParams
+from vllm.distributed import get_tensor_model_parallel_rank
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.model_executor.models.llama import LlamaForCausalLM
+from vllm.utils import cuda_device_count_stateless
+from vllm.v1.engine.async_llm import AsyncLLM
+from vllm.v1.engine.exceptions import EngineDeadError
+
+
+def evil_forward(self, *args, **kwargs):
+    """Evil forward method that raises an exception after
+    NUMBER_OF_GOOD_PASSES calls."""
+    NUMBER_OF_GOOD_PASSES = 10
+
+    if not hasattr(self, "num_calls"):
+        self.num_calls = 0
+
+    if (self.num_calls == NUMBER_OF_GOOD_PASSES
+            and get_tensor_model_parallel_rank() == 0):
+        raise Exception("Simulated illegal memory access on Rank 0!")
+    self.num_calls += 1
+
+    return self.model(*args, **kwargs, intermediate_tensors=None)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
+async def test_async_llm_model_error(monkeypatch, tensor_parallel_size):
+
+    if cuda_device_count_stateless() < tensor_parallel_size:
+        pytest.skip(reason="Not enough CUDA devices")
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        # Monkeypatch an error in the model.
+        monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward)
+
+        engine_args = AsyncEngineArgs(
+            model="meta-llama/Llama-3.2-1B",
+            enforce_eager=True,
+            tensor_parallel_size=tensor_parallel_size)
+        async_llm = AsyncLLM.from_engine_args(engine_args)
+
+        async def generate(request_id: str):
+            generator = async_llm.generate("Hello my name is",
+                                           request_id=request_id,
+                                           sampling_params=SamplingParams())
+            try:
+                async for _ in generator:
+                    pass
+            except Exception as e:
+                return e
+
+        NUM_REQS = 3
+        tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)]
+        outputs = await asyncio.gather(*tasks)
+
+        # Every request should have received an EngineDeadError.
+        for output in outputs:
+            assert isinstance(output, EngineDeadError)
+
+        # AsyncLLM should be errored.
+        assert async_llm.errored
+
+        # We should not be able to make another request.
+        with pytest.raises(EngineDeadError):
+            async for _ in async_llm.generate(
+                    "Hello my name is",
+                    request_id="abc",
+                    sampling_params=SamplingParams()):
+                raise Exception("We should not get here.")
+
+        # Confirm all the processes are cleaned up.
+        wait_for_gpu_memory_to_clear(
+            devices=list(range(tensor_parallel_size)),
+            threshold_bytes=2 * 2**30,
+            timeout_s=60,
+        )
+
+        # NOTE: shutdown is normally handled by the API Server. Since an
+        # exception occurred, it is expected that we call this ourselves.
+        async_llm.shutdown()
+
+
+@pytest.mark.parametrize("enable_multiprocessing", [True, False])
+@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
+def test_llm_model_error(monkeypatch, tensor_parallel_size,
+                         enable_multiprocessing):
+
+    if cuda_device_count_stateless() < tensor_parallel_size:
+        pytest.skip(reason="Not enough CUDA devices")
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        MP_VALUE = "1" if enable_multiprocessing else "0"
+        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE)
+
+        # Monkeypatch an error in the model.
+        monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward)
+
+        llm = LLM(model="meta-llama/Llama-3.2-1B",
+                  enforce_eager=True,
+                  tensor_parallel_size=tensor_parallel_size)
+
+        with pytest.raises(EngineDeadError):
+            llm.generate("Hello my name is Robert and I")
+
+        # Confirm all the processes are cleaned up.
+        wait_for_gpu_memory_to_clear(
+            devices=list(range(tensor_parallel_size)),
+            threshold_bytes=2 * 2**30,
+            timeout_s=60,
+        )
diff --git a/tests/v1/shutdown/test_processor_error.py b/tests/v1/shutdown/test_processor_error.py
new file mode 100644
index 0000000000000..056851025ecae
--- /dev/null
+++ b/tests/v1/shutdown/test_processor_error.py
@@ -0,0 +1,56 @@
+"""Test error handling in Processor. Should not impact other reqs."""
+
+import asyncio
+
+import pytest
+
+from vllm import SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.inputs.data import TokensPrompt
+from vllm.v1.engine.async_llm import AsyncLLM
+from vllm.v1.engine.exceptions import EngineGenerateError
+
+
+@pytest.mark.asyncio
+async def test_async_llm_processor_error(monkeypatch):
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        engine_args = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B",
+                                      enforce_eager=True)
+        async_llm = AsyncLLM.from_engine_args(engine_args)
+
+        async def generate(request_id: str):
+            # An empty token list is not allowed and will raise a
+            # ValueError in the Processor.
+            generator = async_llm.generate(
+                TokensPrompt(prompt_token_ids=[]),
+                request_id=request_id,
+                sampling_params=SamplingParams())
+            try:
+                async for _ in generator:
+                    pass
+            except Exception as e:
+                return e
+
+        NUM_REQS = 3
+        tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)]
+        outputs = await asyncio.gather(*tasks)
+
+        # Every request should have received an EngineGenerateError.
+        for output in outputs:
+            with pytest.raises(EngineGenerateError):
+                raise output
+
+        # AsyncLLM should not be errored (the failure is per-request).
+        assert not async_llm.errored
+
+        # New requests should still work.
+        outputs = []
+        async for out in async_llm.generate(
+                "Hello my name is",
+                request_id="abc",
+                sampling_params=SamplingParams(max_tokens=5)):
+            outputs.append(out)
+        assert len(outputs) == 5
+
+        async_llm.shutdown()
diff --git a/tests/v1/shutdown/test_startup_error.py b/tests/v1/shutdown/test_startup_error.py
new file mode 100644
index 0000000000000..25f2b77b2f3dd
--- /dev/null
+++ b/tests/v1/shutdown/test_startup_error.py
@@ -0,0 +1,87 @@
+"""Test that we handle a startup Error and shutdown."""
+
+import pytest
+
+from tests.utils import wait_for_gpu_memory_to_clear
+from vllm import LLM
+from vllm.distributed import get_tensor_model_parallel_rank
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.model_executor.models.llama import LlamaForCausalLM
+from vllm.utils import cuda_device_count_stateless
+from vllm.v1.engine.async_llm import AsyncLLM
+
+
+def evil_forward(self, *args, **kwargs):
+    """Evil forward method that raises an exception."""
+
+    if get_tensor_model_parallel_rank() == 0:
+        raise Exception("Simulated Error in startup!")
+
+    return self.model(*args, **kwargs, intermediate_tensors=None)
+
+
+MODELS = [
+    "meta-llama/Llama-3.2-1B",  # Raises on first fwd pass.
+    "mistralai/Mixtral-8x22B-Instruct-v0.1",  # Causes OOM.
+]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
+def test_async_llm_startup_error(monkeypatch, model, tensor_parallel_size):
+
+    if cuda_device_count_stateless() < tensor_parallel_size:
+        pytest.skip(reason="Not enough CUDA devices")
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        # Monkeypatch an error in the model.
+        monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward)
+
+        engine_args = AsyncEngineArgs(
+            model=model,
+            enforce_eager=True,
+            tensor_parallel_size=tensor_parallel_size)
+
+        # Confirm we get an exception.
+        with pytest.raises(Exception, match="initialization failed"):
+            _ = AsyncLLM.from_engine_args(engine_args)
+
+        # Confirm all the processes are cleaned up.
+        wait_for_gpu_memory_to_clear(
+            devices=list(range(tensor_parallel_size)),
+            threshold_bytes=2 * 2**30,
+            timeout_s=60,
+        )
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
+@pytest.mark.parametrize("enable_multiprocessing", [True, False])
+def test_llm_startup_error(monkeypatch, model, tensor_parallel_size,
+                           enable_multiprocessing):
+
+    if cuda_device_count_stateless() < tensor_parallel_size:
+        pytest.skip(reason="Not enough CUDA devices")
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        MP_VALUE = "1" if enable_multiprocessing else "0"
+        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE)
+
+        # Monkeypatch an error in the model.
+        monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward)
+
+        with pytest.raises(Exception, match="initialization failed"):
+            _ = LLM(model="meta-llama/Llama-3.2-1B",
+                    enforce_eager=True,
+                    tensor_parallel_size=tensor_parallel_size)
+
+        # Confirm all the processes are cleaned up.
+        wait_for_gpu_memory_to_clear(
+            devices=list(range(tensor_parallel_size)),
+            threshold_bytes=2 * 2**30,
+            timeout_s=60,
+        )
diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py
index 351a39525fa62..59dfb16489190 100644
--- a/vllm/entrypoints/launcher.py
+++ b/vllm/entrypoints/launcher.py
@@ -11,8 +11,10 @@
 from vllm import envs
 from vllm.engine.async_llm_engine import AsyncEngineDeadError
 from vllm.engine.multiprocessing import MQEngineDeadError
+from vllm.engine.protocol import EngineClient
 from vllm.logger import init_logger
 from vllm.utils import find_process_using_port
+from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
 
 logger = init_logger(__name__)
 
@@ -34,11 +36,14 @@ async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
 
     loop = asyncio.get_running_loop()
 
+    watchdog_task = loop.create_task(
+        watchdog_loop(server, app.state.engine_client))
     server_task = loop.create_task(server.serve())
 
     def signal_handler() -> None:
         # prevents the uvicorn signal handler to exit early
         server_task.cancel()
+        watchdog_task.cancel()
 
     async def dummy_shutdown() -> None:
         pass
@@ -58,48 +63,69 @@ async def dummy_shutdown() -> None:
                             port, process, " ".join(process.cmdline()))
             logger.info("Shutting down FastAPI HTTP server.")
             return server.shutdown()
+        finally:
+            watchdog_task.cancel()
+
+
+async def watchdog_loop(server: uvicorn.Server, engine: EngineClient):
+    """
+    Watchdog task that runs in the background, checking
+    for an error state in the engine. Needed to trigger shutdown
+    if an exception arises in a StreamingResponse() generator.
+    """
+    VLLM_WATCHDOG_TIME_S = 5.0
+    while True:
+        await asyncio.sleep(VLLM_WATCHDOG_TIME_S)
+        terminate_if_errored(server, engine)
+
+
+def terminate_if_errored(server: uvicorn.Server, engine: EngineClient):
+    """
+    See discussions here on shutting down a uvicorn server
+    https://github.com/encode/uvicorn/discussions/1103
+    In this case we cannot await the server shutdown here
+    because the handler must first return to close the connection
+    for this request.
+    """
+    engine_errored = engine.errored and not engine.is_running
+    if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine_errored):
+        server.should_exit = True
 
 
 def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
-    """Adds handlers for fatal errors that should crash the server"""
+    """
+    VLLM V1 AsyncLLM catches exceptions and returns
+    only two types: EngineGenerateError and EngineDeadError.
+
+    EngineGenerateError is raised by the per-request generate()
+    method. This error could be request-specific (and therefore
+    recoverable - e.g. if there is an error in input processing).
+
+    EngineDeadError is raised by the background output_handler
+    method. This error is global and therefore not recoverable.
+
+    We register these @app.exception_handlers to return nice
+    responses to the end user if they occur and shut down if needed.
+    See https://fastapi.tiangolo.com/tutorial/handling-errors/
+    for more details on how exception handlers work.
+
+    If an exception is encountered in a StreamingResponse
+    generator, the exception is not raised, since we already sent
+    a 200 status. Rather, we send an error message as the next chunk.
+    Since the exception is not raised, this means that the server
+    will not automatically shut down. Instead, we use the watchdog
+    background task to check for the errored state.
+    """
 
     @app.exception_handler(RuntimeError)
-    async def runtime_error_handler(request: Request, __):
-        """On generic runtime error, check to see if the engine has died.
-        It probably has, in which case the server will no longer be able to
-        handle requests. Trigger a graceful shutdown with a SIGTERM."""
-        engine = request.app.state.engine_client
-        if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
-                and not engine.is_running):
-            logger.fatal("AsyncLLMEngine has failed, terminating server "
-                         "process")
-            # See discussions here on shutting down a uvicorn server
-            # https://github.com/encode/uvicorn/discussions/1103
-            # In this case we cannot await the server shutdown here because
-            # this handler must first return to close the connection for
-            # this request.
-            server.should_exit = True
-
-        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
-
     @app.exception_handler(AsyncEngineDeadError)
-    async def async_engine_dead_handler(_, __):
-        """Kill the server if the async engine is already dead. It will
-        not handle any further requests."""
-        if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
-            logger.fatal("AsyncLLMEngine is already dead, terminating server "
-                         "process")
-            server.should_exit = True
-
-        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
-
     @app.exception_handler(MQEngineDeadError)
-    async def mq_engine_dead_handler(_, __):
-        """Kill the server if the mq engine is already dead. It will
-        not handle any further requests."""
-        if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
-            logger.fatal("MQLLMEngine is already dead, terminating server "
-                         "process")
-            server.should_exit = True
+    @app.exception_handler(EngineDeadError)
+    @app.exception_handler(EngineGenerateError)
+    async def runtime_exception_handler(request: Request, __):
+        terminate_if_errored(
+            server=server,
+            engine=request.app.state.engine_client,
+        )
 
         return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 3c4e35e4aa274..70cdefc311968 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-
 import asyncio
-import os
 from typing import AsyncGenerator, List, Mapping, Optional, Type, Union
 
 import numpy as np
@@ -21,8 +19,9 @@
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import cdiv, kill_process_tree
-from vllm.v1.engine.core_client import EngineCoreClient
+from vllm.utils import cdiv
+from vllm.v1.engine.core_client import AsyncMPClient
+from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
 from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
@@ -47,8 +46,6 @@ def __init__(
         start_engine_loop: bool = True,
     ) -> None:
 
-        assert start_engine_loop
-
         self.model_config = vllm_config.model_config
 
         self.log_requests = log_requests
@@ -80,9 +77,7 @@ def __init__(
             log_stats=self.log_stats)
 
         # EngineCore (starts the engine in background process).
-        self.engine_core = EngineCoreClient.make_client(
-            multiprocess_mode=True,
-            asyncio_mode=True,
+        self.engine_core = AsyncMPClient(
             vllm_config=vllm_config,
             executor_class=executor_class,
         )
@@ -139,6 +134,9 @@ async def add_request(
     ) -> asyncio.Queue[RequestOutput]:
         """Add new request to the AsyncLLM."""
 
+        if self.errored:
+            raise EngineDeadError()
+
         # 1) Create a new output queue for the request.
         if self.output_processor.is_request_active(request_id):
             raise ValueError(f"Request id {request_id} already running.")
@@ -216,11 +214,15 @@ async def generate(
             while not finished:
                 # Note: drain queue without await if possible (avoids
                 # task switching under load which helps performance).
-                out = q.get_nowait() if not q.empty() else await q.get()
+                out = q.get_nowait() if q.qsize() > 0 else await q.get()
+                if isinstance(out, EngineDeadError):
+                    raise out
 
                 # Coalesce any additional queued outputs
                 while not q.empty():
                     next_out = q.get_nowait()
+                    if isinstance(next_out, EngineDeadError):
+                        raise next_out
                     if sampling_params.output_kind == RequestOutputKind.DELTA:
                         out.add(next_out)
                     else:
@@ -231,13 +233,27 @@ async def generate(
                 finished = out.finished
                 yield out
 
-            # If the request is disconnected by the client, the
-            # generate() task will be canceled. So, we abort the
-            # request if we end up here.
+        # If the request is disconnected by the client, generate()
+        # is cancelled. So, we abort the request if we end up here.
         except asyncio.CancelledError:
             await self.abort(request_id)
+            if self.log_requests:
+                logger.info("Request %s aborted.", request_id)
+            raise
+
+        # Engine is dead. Do not abort since we shut down.
+        except EngineDeadError:
+            if self.log_requests:
+                logger.info("Request %s failed.", request_id)
             raise
 
+        # Error in the generate() task (possibly recoverable).
+        except Exception as e:
+            await self.abort(request_id)
+            if self.log_requests:
+                logger.info("Request %s failed.", request_id)
+            raise EngineGenerateError() from e
+
     async def _run_output_handler(self):
         """Background loop: pulls from EngineCore and pushes to AsyncStreams."""
 
@@ -284,8 +300,9 @@ async def _run_output_handler(self):
             )
 
         except Exception as e:
-            logger.exception("EngineCore output handler hit an error: %s", e)
-            kill_process_tree(os.getpid())
+            logger.error("AsyncLLM output_handler got an Exception:",
+                         exc_info=e)
+            self.output_processor.propagate_error(EngineDeadError())
 
     async def abort(self, request_id: str) -> None:
         """Abort RequestId in OutputProcessor and EngineCore."""
@@ -359,19 +376,19 @@ async def reset_prefix_cache(self) -> None:
 
     @property
     def is_running(self) -> bool:
-        return True
+        return not self.errored
 
     @property
     def is_stopped(self) -> bool:
-        return False
+        return self.errored
 
     @property
    def errored(self) -> bool:
-        return False
+        return self.engine_core.engine_core_errored
 
     @property
     def dead_error(self) -> BaseException:
-        return Exception()  # TODO: implement
+        return EngineDeadError()
 
     async def add_lora(self, lora_request: LoRARequest) -> None:
         """Load a new LoRA adapter into the engine for future requests."""
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index f3d40aa1e9cb2..0bf0ecb9112dc 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -19,9 +19,10 @@
 from vllm.utils import get_exception_traceback, zmq_socket_ctx
 from vllm.v1.core.kv_cache_utils import get_kv_cache_config
 from vllm.v1.core.scheduler import Scheduler
-from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
-                            EngineCoreRequest, EngineCoreRequestType,
-                            EngineCoreRequestUnion, EngineCoreResetPrefixCache)
+from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
+                            EngineCoreProfile, EngineCoreRequest,
+                            EngineCoreRequestType, EngineCoreRequestUnion,
+                            EngineCoreResetPrefixCache)
 from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.request import Request, RequestStatus
@@ -152,26 +153,33 @@ def __init__(
         executor_class: Type[Executor],
         log_stats: bool = False,
     ):
-        super().__init__(vllm_config, executor_class)
-
-        self.log_stats = log_stats
-
-        # Background Threads and Queues for IO. These enable us to
-        # overlap ZMQ socket IO with GPU since they release the GIL,
-        # and to overlap some serialization/deserialization with the
-        # model forward pass.
-        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
-        self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
-        self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
-        threading.Thread(target=self.process_input_socket,
-                         args=(input_path, ),
-                         daemon=True).start()
-        threading.Thread(target=self.process_output_socket,
-                         args=(output_path, ),
-                         daemon=True).start()
-
-        # Send Readiness signal to EngineClient.
-        ready_pipe.send({"status": "READY"})
+        try:
+            super().__init__(vllm_config, executor_class)
+
+            self.log_stats = log_stats
+
+            # Background Threads and Queues for IO. These enable us to
+            # overlap ZMQ socket IO with GPU since they release the GIL,
+            # and to overlap some serialization/deserialization with the
+            # model forward pass.
+            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
+            self.input_queue: queue.Queue[
+                EngineCoreRequestUnion] = queue.Queue()
+            self.output_queue: queue.Queue[
+                List[EngineCoreOutput]] = queue.Queue()
+            threading.Thread(target=self.process_input_socket,
+                             args=(input_path, ),
+                             daemon=True).start()
+            threading.Thread(target=self.process_output_socket,
+                             args=(output_path, ),
+                             daemon=True).start()
+
+            # Send Readiness signal to EngineClient.
+            ready_pipe.send({"status": "READY"})
+
+        except Exception as e:
+            logger.exception("EngineCore got error at startup:", exc_info=e)
+            ready_pipe.send({"status": "FAILED"})
 
     @staticmethod
     def run_engine_core(*args, **kwargs):
@@ -230,7 +238,7 @@ def run_busy_loop(self):
                         # Break out the loop so we can log_stats in step().
                         if self.log_stats:
                             break
-            except BaseException:
+            except Exception:
                 raise
 
         # 2) Handle any new client requests (Abort or Add).
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index cdc63acdb7469..2fe1873b6440f 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -1,23 +1,21 @@
 # SPDX-License-Identifier: Apache-2.0
-
 import asyncio
-import os
 import signal
 import weakref
 from abc import ABC, abstractmethod
-from typing import List, Optional, Type
+from typing import List, Optional, Type, Union
 
 import zmq
 import zmq.asyncio
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
-                        make_zmq_socket)
+from vllm.utils import get_open_zmq_ipc_path, make_zmq_socket
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
                             EngineCoreRequest, EngineCoreRequestType,
                             EngineCoreRequestUnion, EngineCoreResetPrefixCache)
 from vllm.v1.engine.core import EngineCore, EngineCoreProc
+from vllm.v1.engine.exceptions import EngineDeadError
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.serial_utils import MsgpackDecoder, PickleEncoder
 from vllm.v1.utils import BackgroundProcHandle
@@ -146,20 +144,6 @@ def __init__(
         executor_class: Type[Executor],
         log_stats: bool,
     ):
-        # The child processes will send SIGUSR1 when unrecoverable
-        # errors happen. We kill the process tree here so that the
-        # stack trace is very evident.
-        # TODO(rob): rather than killing the main process, we should
-        # figure out how to raise an AsyncEngineDeadError and
-        # handle at the API server level so we can return a better
-        # error code to the clients calling VLLM.
-        def sigusr1_handler(signum, frame):
-            logger.fatal("Got fatal signal from worker processes, shutting "
-                         "down. See stack trace above for root cause issue.")
-            kill_process_tree(os.getpid())
-
-        signal.signal(signal.SIGUSR1, sigusr1_handler)
-
         # Serialization setup.
         self.encoder = PickleEncoder()
         self.decoder = MsgpackDecoder(EngineCoreOutputs)
@@ -183,6 +167,7 @@ def sigusr1_handler(signum, frame):
                                                zmq.constants.PUSH)
 
         # Start EngineCore in background process.
+        self.engine_core_errored = False
         self.proc_handle = BackgroundProcHandle(
             input_path=input_path,
             output_path=output_path,
@@ -193,20 +178,44 @@ def sigusr1_handler(signum, frame):
                 "executor_class": executor_class,
                 "log_stats": log_stats,
             })
+        self.proc_handle.wait_for_startup()
 
     def shutdown(self):
         """Clean up background resources."""
-        if hasattr(self, "proc_handle"):
-            self.proc_handle.shutdown()
+        self.proc_handle.shutdown()
 
         self._finalizer()
 
+    def _sigusr1_handler(self):
+        """
+        EngineCoreProc sends SIGUSR1 if it encounters an Exception.
+        Set self into the errored state and begin shutdown.
+        """
+        logger.fatal("Got fatal signal from EngineCore, shutting down.")
+        self.engine_core_errored = True
+        self.shutdown()
+
+    def _format_exception(self, e: Exception) -> Exception:
+        """If errored, use EngineDeadError so root cause is clear."""
+
+        return (EngineDeadError(
+            "EngineCore encountered an issue. See stack trace "
+            "for the root cause.",
+            suppress_context=True) if self.engine_core_errored else e)
+
 
 class SyncMPClient(MPClient):
     """Synchronous client for multi-proc EngineCore."""
 
     def __init__(self, vllm_config: VllmConfig,
                  executor_class: Type[Executor]):
+
+        # Setup EngineCore signal handler.
+        def sigusr1_handler(signum, frame):
+            self._sigusr1_handler()
+
+        signal.signal(signal.SIGUSR1, sigusr1_handler)
+
         super().__init__(
             asyncio_mode=False,
             vllm_config=vllm_config,
@@ -216,15 +225,21 @@ def __init__(self, vllm_config: VllmConfig,
 
     def get_output(self) -> EngineCoreOutputs:
 
-        (frame, ) = self.output_socket.recv_multipart(copy=False)
-        return self.decoder.decode(frame.buffer)
+        try:
+            (frame, ) = self.output_socket.recv_multipart(copy=False)
+            engine_core_outputs = self.decoder.decode(frame.buffer)
+            return engine_core_outputs
+        except Exception as e:
+            raise self._format_exception(e) from None
 
     def _send_input(self, request_type: EngineCoreRequestType,
                     request: EngineCoreRequestUnion) -> None:
-
-        # (RequestType, SerializedRequest)
-        msg = (request_type.value, self.encoder.encode(request))
-        self.input_socket.send_multipart(msg, copy=False)
+        try:
+            # (RequestType, SerializedRequest)
+            msg = (request_type.value, self.encoder.encode(request))
+            self.input_socket.send_multipart(msg, copy=False)
+        except Exception as e:
+            raise self._format_exception(e) from None
 
     def add_request(self, request: EngineCoreRequest) -> None:
         # NOTE: text prompt is not needed in the core engine as it has been
@@ -250,6 +265,18 @@ class AsyncMPClient(MPClient):
 
     def __init__(self, vllm_config: VllmConfig,
                  executor_class: Type[Executor]):
+
+        # EngineCore sends SIGUSR1 when it gets an Exception.
+        # NOTE: super().__init__ blocks the event loop until
+        # background procs are setup. This handler allows us
+        # to catch issues during startup (e.g. OOM). We switch
+        # to a signal handler in the event loop __init__.
+        def sigusr1_handler(signum, frame):
+            self._sigusr1_handler()
+
+        # signal.signal(signal.SIGUSR1, sigusr1_handler)
+
+        # Initialize EngineCore + all background processes.
         super().__init__(
             asyncio_mode=True,
             vllm_config=vllm_config,
@@ -257,30 +284,49 @@ def __init__(self, vllm_config: VllmConfig,
             log_stats=True,
         )
 
-        self.outputs_queue: Optional[asyncio.Queue[bytes]] = None
+        # ZMQ IO. Run it in background task so that we can
+        # overlap with AsyncLLM.output_handler_loop. This
+        # works because ZMQ IO releases the GIL.
         self.queue_task: Optional[asyncio.Task] = None
+        self.outputs_queue: asyncio.Queue[Union[EngineCoreOutputs,
+                                                Exception]] = asyncio.Queue()
+
+    def shutdown(self):
+        super().shutdown()
+        if queue_task := getattr(self, "queue_task", None):
+            queue_task.cancel()
+
+    async def _process_outputs_socket_loop(self):
+        try:
+            while True:
+                (frame, ) = await self.output_socket.recv_multipart(
+                    copy=False)
+                outputs = self.decoder.decode(frame.buffer)
+                self.outputs_queue.put_nowait(outputs)
+        except Exception as e:
+            self.outputs_queue.put_nowait(e)
 
     async def get_output_async(self) -> EngineCoreOutputs:
-        if self.outputs_queue is None:
-            # Perform IO in separate task to parallelize as much as possible
-            self.outputs_queue = asyncio.Queue()
 
-            async def process_outputs_socket():
-                assert self.outputs_queue is not None
-                while True:
-                    (frame, ) = await self.output_socket.recv_multipart(
-                        copy=False)
-                    self.outputs_queue.put_nowait(frame.buffer)
+        # Start output loop on the first call.
+        if self.queue_task is None:
+            self.queue_task = asyncio.create_task(
+                self._process_outputs_socket_loop())
 
-            self.queue_task = asyncio.create_task(process_outputs_socket())
+        # NOTE: if an exception arises processing the socket,
+        # the exception is forwarded to the queue.
+        outputs = await self.outputs_queue.get()
+        if isinstance(outputs, Exception):
+            raise self._format_exception(outputs) from None
 
-        return self.decoder.decode(await self.outputs_queue.get())
+        return outputs
 
     async def _send_input(self, request_type: EngineCoreRequestType,
                           request: EngineCoreRequestUnion) -> None:
-
-        msg = (request_type.value, self.encoder.encode(request))
-        await self.input_socket.send_multipart(msg, copy=False)
+        try:
+            msg = (request_type.value, self.encoder.encode(request))
+            await self.input_socket.send_multipart(msg, copy=False)
+        except Exception as e:
+            raise self._format_exception(e) from None
 
     async def add_request_async(self, request: EngineCoreRequest) -> None:
         # NOTE: text prompt is not needed in the core engine as it has been
diff --git a/vllm/v1/engine/exceptions.py b/vllm/v1/engine/exceptions.py
new file mode 100644
index 0000000000000..34ec1f6b0cd08
--- /dev/null
+++ b/vllm/v1/engine/exceptions.py
@@ -0,0 +1,14 @@
+# SPDX-License-Identifier: Apache-2.0
+class EngineGenerateError(Exception):
+    """Raised when AsyncLLM.generate() fails. Recoverable."""
+    pass
+
+
+class EngineDeadError(Exception):
+    """Raised when the EngineCore dies. Unrecoverable."""
+
+    def __init__(self, *args, suppress_context: bool = False, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Make stack trace clearer when using with LLMEngine by
+        # silencing irrelevant ZMQError.
+        self.__suppress_context__ = suppress_context
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 5dbf530caa17a..69ff2fa2a8022 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -93,6 +93,13 @@ def get_num_unfinished_requests(self):
     def has_unfinished_requests(self) -> bool:
         return len(self.request_states) > 0
 
+    def propagate_error(self, e: Exception):
+        """Propagate error to all generate() tasks."""
+
+        for _, state in self.request_states.items():
+            assert state.queue is not None
+            state.queue.put_nowait(e)
+
     def abort_requests(
         self,
         request_ids: List[str],
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index e3f07172d8cd9..2bf094e9d7265 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -21,6 +21,7 @@
                               destroy_model_parallel)
 from vllm.distributed.device_communicators.shm_broadcast import (Handle,
                                                                  MessageQueue)
+from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING
 from vllm.executor.multiproc_worker_utils import (
     _add_prefix, set_multiprocessing_worker_envs)
 from vllm.logger import init_logger
@@ -48,10 +49,13 @@ def sigusr1_handler(signum, frame):
             logger.fatal(
                 "MulitprocExecutor got fatal signal from worker processes, "
                 "shutting down. See stack trace above for root cause issue.")
-            # Propagate error up to parent process.
-            parent_process = psutil.Process().parent()
-            parent_process.send_signal(signal.SIGUSR1)
+            # Shutdown first (avoid SystemExit exceptions in __del__).
             self.shutdown()
+            if VLLM_ENABLE_V1_MULTIPROCESSING:
+                # TODO(rob): move this to the VLLMConfig.
+                # Propagate up if using the mp engine. Note that
+                # sending in non-mp mode crashes caller process.
+                psutil.Process().parent().send_signal(signal.SIGUSR1)
 
         signal.signal(signal.SIGUSR1, sigusr1_handler)
 
@@ -170,7 +174,7 @@ def _cleanup_sockets(self):
     def shutdown(self):
         """Properly shut down the executor and its workers"""
-        if getattr(self, 'shutting_down', False):
+        if not getattr(self, 'shutting_down', False):
             self.shutting_down = True
             for w in self.workers:
                 w.worker_response_mq = None
@@ -319,12 +323,20 @@ def signal_handler(signum, frame):
 
         except SystemExit:
             logger.debug("Worker interrupted.")
 
-        except Exception:
+        except Exception as e:
+            # Log rather than raise so the stack trace is in order of
+            # WorkerProc -> EngineCore -> AsyncLLM.
+            logger.exception("WorkerProc got an Exception:", exc_info=e)
+
+            # The parent will send a SIGTERM to all worker processes
+            # after we send SIGUSR1. Set this value so we don't re-throw
+            # SystemExit(), to avoid zmq exceptions during __del__.
+            shutdown_requested = True
+
             # worker_busy_loop sends exceptions exceptons to Executor
             # for shutdown, but if there is an error in startup or an
             # error with IPC itself, we need to alert the parent.
             psutil.Process().parent().send_signal(signal.SIGUSR1)
-            raise
 
         finally:
             # Clean up once worker exits busy loop
@@ -342,7 +354,7 @@ def wait_for_startup(
         # Wait for Worker to send READY.
         while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
-            logger.debug("Waiting for WorkerProc to startup.")
+            logger.info("Waiting for WorkerProc to startup.")
 
             if not proc.is_alive():
                 raise RuntimeError("WorkerProc failed to start.")
diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py
index 5494542c181d7..2f5168296e677 100644
--- a/vllm/v1/utils.py
+++ b/vllm/v1/utils.py
@@ -102,7 +102,8 @@ def __init__(
         process_kwargs: Dict[Any, Any],
     ):
         context = get_mp_context()
-        reader, writer = context.Pipe(duplex=False)
+        self.reader, writer = context.Pipe(duplex=False)
+        self.process_name = process_name
 
         assert ("ready_pipe" not in process_kwargs
                 and "input_path" not in process_kwargs
@@ -111,20 +112,38 @@ def __init__(
         process_kwargs["input_path"] = input_path
         process_kwargs["output_path"] = output_path
 
+        # Flag for shutdown state. Background procs send a signal when an
+        # error occurs, which triggers shutdown(). If we are in the startup
+        # loop when signaled, this flag breaks us out.
+        self.shutting_down = False
+
         # Run busy loop in background process.
         self.proc = context.Process(target=target_fn, kwargs=process_kwargs)
         self._finalizer = weakref.finalize(self, shutdown, self.proc,
                                            input_path, output_path)
         self.proc.start()
 
-        # Wait for startup.
-        if reader.recv()["status"] != "READY":
-            raise RuntimeError(f"{process_name} initialization failed. "
-                               "See root cause above.")
-
     def shutdown(self):
+        self.shutting_down = True
         self._finalizer()
 
+    def wait_for_startup(self):
+        """Wait until the background process is ready."""
+
+        e = Exception(f"{self.process_name} initialization failed due to "
+                      "an exception in a background process. See stack trace "
+                      "for root cause.")
+
+        while not self.reader.poll(timeout=1):
+            if self.shutting_down:
+                raise e
+        try:
+            if self.reader.recv()["status"] != "READY":
+                raise e
+        except EOFError:
+            e.__suppress_context__ = True
+            raise e from None
+
 
 # Note(rob): shutdown function cannot be a bound method,
 # else the gc cannot collect the object.
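
For reference, a minimal usage sketch (not part of the diff) of how a caller can distinguish the two exception types that AsyncLLM.generate() surfaces after this change. The model name, prompt, and call signatures mirror the tests above; the request id is an arbitrary placeholder.

import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError


async def main():
    async_llm = AsyncLLM.from_engine_args(
        AsyncEngineArgs(model="meta-llama/Llama-3.2-1B", enforce_eager=True))
    try:
        async for out in async_llm.generate(
                "Hello my name is",
                request_id="req-0",
                sampling_params=SamplingParams(max_tokens=16)):
            print(out)
    except EngineGenerateError:
        # Request-specific failure (e.g. bad input); other requests and the
        # engine itself are unaffected, so we can keep serving.
        pass
    except EngineDeadError:
        # The EngineCore background process died; the AsyncLLM is now
        # unrecoverable and reports errored = True.
        assert async_llm.errored
    finally:
        async_llm.shutdown()


if __name__ == "__main__":
    asyncio.run(main())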