[V1][Frontend] Improve Shutdown And Logs #11737

Open
wants to merge 55 commits into base: main

Commits (55)
eb16239
checkpoint prototype
robertgshaw2-redhat Jan 3, 2025
8549fdd
Issue currently is with streaming. The HTTP exception handlers do not…
robertgshaw2-redhat Jan 3, 2025
77801cd
switch from ValueError -> Exception.
robertgshaw2-redhat Jan 4, 2025
1bbc3a4
merged
robertgshaw2-redhat Jan 4, 2025
8eca864
updated
robertgshaw2-redhat Jan 4, 2025
b8c77b3
stash
robertgshaw2-redhat Jan 4, 2025
ce9b8ef
stash
robertgshaw2-redhat Jan 4, 2025
3a760a7
add watchdog
robertgshaw2-redhat Jan 4, 2025
3024da0
updated
robertgshaw2-redhat Jan 4, 2025
5af8189
revert spurious changes
robertgshaw2-redhat Jan 4, 2025
3cb21bb
updated
robertgshaw2-redhat Jan 4, 2025
7c97308
updated
robertgshaw2-redhat Jan 4, 2025
ea6824a
updated
robertgshaw2-redhat Jan 4, 2025
b278065
remove cruft
robertgshaw2-redhat Jan 4, 2025
c004bd4
cruft
robertgshaw2-redhat Jan 4, 2025
2556bc4
stash
robertgshaw2-redhat Jan 4, 2025
db0b9e6
fix llama
robertgshaw2-redhat Jan 4, 2025
f722589
updated
robertgshaw2-redhat Jan 4, 2025
de75cc4
cruft
robertgshaw2-redhat Jan 4, 2025
ba5ca87
cruft
robertgshaw2-redhat Jan 4, 2025
4f6b68a
updated
robertgshaw2-redhat Jan 4, 2025
949d425
updated
robertgshaw2-redhat Jan 4, 2025
f67398b
updated
robertgshaw2-redhat Jan 4, 2025
b3d2994
updated
robertgshaw2-redhat Jan 4, 2025
34a997a
update comment
robertgshaw2-redhat Jan 4, 2025
32cf91b
update comment
robertgshaw2-redhat Jan 4, 2025
c73801c
fix more
robertgshaw2-redhat Jan 4, 2025
1188845
updated
robertgshaw2-redhat Jan 4, 2025
706782c
udpatd
robertgshaw2-redhat Jan 4, 2025
1cc0915
added exception file
robertgshaw2-redhat Jan 4, 2025
8db0eee
updated
robertgshaw2-redhat Jan 4, 2025
2fc8af6
fixt
robertgshaw2-redhat Jan 4, 2025
de39af1
reduce cruft
robertgshaw2-redhat Jan 5, 2025
732ba64
reduce cruft
robertgshaw2-redhat Jan 5, 2025
4372094
cleanup
robertgshaw2-redhat Jan 5, 2025
b9144a3
updated
robertgshaw2-redhat Jan 5, 2025
d90e122
cruft
robertgshaw2-redhat Jan 5, 2025
2bbac31
updated
robertgshaw2-redhat Jan 5, 2025
c40542a
revert changes to server
robertgshaw2-redhat Jan 5, 2025
46734eb
revert debug cruft
robertgshaw2-redhat Jan 5, 2025
f0baffb
fix error
robertgshaw2-redhat Jan 5, 2025
8a7f18e
added tests
robertgshaw2-redhat Jan 5, 2025
a662940
revert
robertgshaw2-redhat Jan 5, 2025
4ee6390
fixed
robertgshaw2-redhat Jan 5, 2025
3e23ee2
updated
robertgshaw2-redhat Jan 5, 2025
45456f9
fixed error
robertgshaw2-redhat Jan 5, 2025
6128b1a
update test coverage
robertgshaw2-redhat Jan 5, 2025
de24559
stash
robertgshaw2-redhat Jan 5, 2025
7adf26e
added tests
robertgshaw2-redhat Jan 6, 2025
bf92854
stash
robertgshaw2-redhat Jan 7, 2025
8dae5c6
updated
robertgshaw2-redhat Feb 7, 2025
6b4fe88
updated
robertgshaw2-redhat Feb 7, 2025
efe85ee
updared
robertgshaw2-redhat Feb 7, 2025
6195795
fix typo
robertgshaw2-redhat Feb 7, 2025
0b25586
updated
robertgshaw2-redhat Feb 7, 2025
121 changes: 121 additions & 0 deletions tests/v1/shutdown/test_forward_error.py
@@ -0,0 +1,121 @@
"""Test that we handle an Error in model forward and shutdown."""

import asyncio

import pytest

from tests.utils import wait_for_gpu_memory_to_clear
from vllm import LLM, SamplingParams
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineDeadError


def evil_forward(self, *args, **kwargs):
    """Evil forward method that raises an exception on rank 0 after
    NUMBER_OF_GOOD_PASSES forward passes."""
    NUMBER_OF_GOOD_PASSES = 10

    if not hasattr(self, "num_calls"):
        self.num_calls = 0

    if (self.num_calls == NUMBER_OF_GOOD_PASSES
            and get_tensor_model_parallel_rank() == 0):
        raise Exception("Simulated illegal memory access on Rank 0!")
    self.num_calls += 1

    return self.model(*args, **kwargs, intermediate_tensors=None)


@pytest.mark.asyncio
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
async def test_async_llm_model_error(monkeypatch, tensor_parallel_size):

    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        # Monkeypatch an error in the model.
        monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward)

        engine_args = AsyncEngineArgs(
            model="meta-llama/Llama-3.2-1B",
            enforce_eager=True,
            tensor_parallel_size=tensor_parallel_size)
        async_llm = AsyncLLM.from_engine_args(engine_args)

        async def generate(request_id: str):
            generator = async_llm.generate("Hello my name is",
                                           request_id=request_id,
                                           sampling_params=SamplingParams())
            try:
                async for _ in generator:
                    pass
            except Exception as e:
                return e

        NUM_REQS = 3
        tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)]
        outputs = await asyncio.gather(*tasks)

        # Every request should get an EngineDeadError.
        for output in outputs:
            assert isinstance(output, EngineDeadError)

        # AsyncLLM should be errored.
        assert async_llm.errored

        # We should not be able to make another request.
        with pytest.raises(EngineDeadError):
            async for _ in async_llm.generate(
                    "Hello my name is",
                    request_id="abc",
                    sampling_params=SamplingParams()):
                raise Exception("We should not get here.")

        # Confirm all the processes are cleaned up.
        wait_for_gpu_memory_to_clear(
            devices=list(range(tensor_parallel_size)),
            threshold_bytes=2 * 2**30,
            timeout_s=60,
        )

        # NOTE: shutdown is normally handled by the API server when an
        # exception occurs, so we call it explicitly here in the test.
        async_llm.shutdown()


@pytest.mark.parametrize("enable_multiprocessing", [True, False])
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
def test_llm_model_error(monkeypatch, tensor_parallel_size,
                         enable_multiprocessing):

    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        MP_VALUE = "1" if enable_multiprocessing else "0"
        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE)

        # Monkeypatch an error in the model.
        monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward)

        llm = LLM(model="meta-llama/Llama-3.2-1B",
                  enforce_eager=True,
                  tensor_parallel_size=tensor_parallel_size)

        with pytest.raises(EngineDeadError):
            llm.generate("Hello my name is Robert and I")

        # Confirm all the processes are cleaned up.
        wait_for_gpu_memory_to_clear(
            devices=list(range(tensor_parallel_size)),
            threshold_bytes=2 * 2**30,
            timeout_s=60,
        )
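Aside: these tests patch forward at the class level, so the replacement function receives self exactly like a bound method and applies to every instance created after the patch. A minimal standalone sketch of that pytest monkeypatch pattern (the Model class and evil function below are illustrative stand-ins, not vLLM code):

import pytest


class Model:
    """Stand-in for a model class such as LlamaForCausalLM."""

    def forward(self, x: int) -> int:
        return x + 1


def evil(self, x: int) -> int:
    """Replacement that simulates a failure inside forward()."""
    raise RuntimeError("Simulated error in forward!")


def test_patched_forward(monkeypatch):
    # Patching on the class (not on an instance) means instances created
    # after the patch also pick up the failing forward.
    monkeypatch.setattr(Model, "forward", evil)
    with pytest.raises(RuntimeError):
        Model().forward(1)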
56 changes: 56 additions & 0 deletions tests/v1/shutdown/test_processor_error.py
@@ -0,0 +1,56 @@
"""Test error handling in Processor. Should not impact other reqs."""

import asyncio

import pytest

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs.data import TokensPrompt
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineGenerateError


@pytest.mark.asyncio
async def test_async_llm_processor_error(monkeypatch):

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        engine_args = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B",
                                      enforce_eager=True)
        async_llm = AsyncLLM.from_engine_args(engine_args)

        async def generate(request_id: str):
            # [] is not allowed and will raise a ValueError in Processor.
            generator = async_llm.generate(TokensPrompt([]),
                                           request_id=request_id,
                                           sampling_params=SamplingParams())
            try:
                async for _ in generator:
                    pass
            except Exception as e:
                return e

        NUM_REQS = 3
        tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)]
        outputs = await asyncio.gather(*tasks)

        # Every request should get an EngineGenerateError.
        for output in outputs:
            with pytest.raises(EngineGenerateError):
                raise output

        # AsyncLLM should not be errored (the failure is per-request).
        assert not async_llm.errored

        # This should be no problem: new requests still work.
        outputs = []
        async for out in async_llm.generate(
                "Hello my name is",
                request_id="abc",
                sampling_params=SamplingParams(max_tokens=5)):
            outputs.append(out)
        assert len(outputs) == 5

        async_llm.shutdown()
87 changes: 87 additions & 0 deletions tests/v1/shutdown/test_startup_error.py
@@ -0,0 +1,87 @@
"""Test that we handle a startup Error and shutdown."""

import pytest

from tests.utils import wait_for_gpu_memory_to_clear
from vllm import LLM
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM


def evil_forward(self, *args, **kwargs):
    """Evil forward method that raises an exception on rank 0."""

    if get_tensor_model_parallel_rank() == 0:
        raise Exception("Simulated Error in startup!")

    return self.model(*args, **kwargs, intermediate_tensors=None)


MODELS = [
    "meta-llama/Llama-3.2-1B",  # Raises on first fwd pass.
    "mistralai/Mixtral-8x22B-Instruct-v0.1"  # Causes OOM.
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
def test_async_llm_startup_error(monkeypatch, model, tensor_parallel_size):

    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        # Monkeypatch an error in the model.
        monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward)

        engine_args = AsyncEngineArgs(
            model=model,
            enforce_eager=True,
            tensor_parallel_size=tensor_parallel_size)

        # Confirm we get an exception.
        with pytest.raises(Exception, match="initialization failed"):
            _ = AsyncLLM.from_engine_args(engine_args)

        # Confirm all the processes are cleaned up.
        wait_for_gpu_memory_to_clear(
            devices=list(range(tensor_parallel_size)),
            threshold_bytes=2 * 2**30,
            timeout_s=60,
        )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("enable_multiprocessing", [True, False])
def test_llm_startup_error(monkeypatch, model, tensor_parallel_size,
                           enable_multiprocessing):

    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        MP_VALUE = "1" if enable_multiprocessing else "0"
        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE)

        # Monkeypatch an error in the model.
        monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward)

        with pytest.raises(Exception, match="initialization failed"):
            _ = LLM(model="meta-llama/Llama-3.2-1B",
                    enforce_eager=True,
                    tensor_parallel_size=tensor_parallel_size)

        # Confirm all the processes are cleaned up.
        wait_for_gpu_memory_to_clear(
            devices=list(range(tensor_parallel_size)),
            threshold_bytes=2 * 2**30,
            timeout_s=60,
        )
98 changes: 62 additions & 36 deletions vllm/entrypoints/launcher.py
@@ -11,8 +11,10 @@
from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.multiprocessing import MQEngineDeadError
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger
from vllm.utils import find_process_using_port
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError

logger = init_logger(__name__)

@@ -34,11 +36,14 @@ async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):

    loop = asyncio.get_running_loop()

    watchdog_task = loop.create_task(
        watchdog_loop(server, app.state.engine_client))
    server_task = loop.create_task(server.serve())

    def signal_handler() -> None:
        # prevents the uvicorn signal handler from exiting early
        server_task.cancel()
        watchdog_task.cancel()

    async def dummy_shutdown() -> None:
        pass
@@ -58,48 +63,69 @@ async def dummy_shutdown() -> None:
                port, process, " ".join(process.cmdline()))
        logger.info("Shutting down FastAPI HTTP server.")
        return server.shutdown()
    finally:
        watchdog_task.cancel()


async def watchdog_loop(server: uvicorn.Server, engine: EngineClient):
    """
    Watchdog task that runs in the background, checking
    for error state in the engine. Needed to trigger shutdown
    if an exception arises in a StreamingResponse() generator.
    """
    VLLM_WATCHDOG_TIME_S = 5.0
    while True:
        await asyncio.sleep(VLLM_WATCHDOG_TIME_S)
        terminate_if_errored(server, engine)


def terminate_if_errored(server: uvicorn.Server, engine: EngineClient):
    """
    See discussions here on shutting down a uvicorn server
    https://github.com/encode/uvicorn/discussions/1103
    In this case we cannot await the server shutdown here
    because the handler must first return to close the connection
    for this request.
    """
    engine_errored = engine.errored and not engine.is_running
    if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine_errored):
        server.should_exit = True


def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
    """Adds handlers for fatal errors that should crash the server"""
    """
    vLLM V1 AsyncLLM catches exceptions and returns
    only two types: EngineGenerateError and EngineDeadError.

    EngineGenerateError is raised by the per-request generate()
    method. This error could be request-specific (and therefore
    recoverable, e.g. if there is an error in input processing).

    EngineDeadError is raised by the background output_handler
    method. This error is global and therefore not recoverable.

    We register these @app.exception_handlers to return nice
    responses to the end user if they occur and shut down if needed.
    See https://fastapi.tiangolo.com/tutorial/handling-errors/
    for more details on how exception handlers work.

    If an exception is encountered in a StreamingResponse
    generator, the exception is not raised, since we already sent
    a 200 status. Rather, we send an error message as the next chunk.
    Since the exception is not raised, the server will not
    automatically shut down. Instead, we use the watchdog
    background task to check for an errored state.
    """

    @app.exception_handler(RuntimeError)
    async def runtime_error_handler(request: Request, __):
        """On generic runtime error, check to see if the engine has died.
        It probably has, in which case the server will no longer be able to
        handle requests. Trigger a graceful shutdown with a SIGTERM."""
        engine = request.app.state.engine_client
        if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
                and not engine.is_running):
            logger.fatal("AsyncLLMEngine has failed, terminating server "
                         "process")
            # See discussions here on shutting down a uvicorn server
            # https://github.com/encode/uvicorn/discussions/1103
            # In this case we cannot await the server shutdown here because
            # this handler must first return to close the connection for
            # this request.
            server.should_exit = True

        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)

    @app.exception_handler(AsyncEngineDeadError)
    async def async_engine_dead_handler(_, __):
        """Kill the server if the async engine is already dead. It will
        not handle any further requests."""
        if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
            logger.fatal("AsyncLLMEngine is already dead, terminating server "
                         "process")
            server.should_exit = True

        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)

    @app.exception_handler(MQEngineDeadError)
    async def mq_engine_dead_handler(_, __):
        """Kill the server if the mq engine is already dead. It will
        not handle any further requests."""
        if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
            logger.fatal("MQLLMEngine is already dead, terminating server "
                         "process")
            server.should_exit = True
    @app.exception_handler(EngineDeadError)
    @app.exception_handler(EngineGenerateError)
    async def runtime_exception_handler(request: Request, __):
        terminate_if_errored(
            server=server,
            engine=request.app.state.engine_client,
        )

        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
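For additional context on the StreamingResponse caveat described in the docstring above: once the first chunk of a streaming response has been sent, the 200 status line is already on the wire, so a registered @app.exception_handler can no longer turn a failure into an error response. The sketch below illustrates the pattern of converting such an exception into a final error chunk and leaving shutdown to the watchdog; the wrapper name and SSE payload format are illustrative assumptions, not vLLM's actual API.

import json
from typing import AsyncGenerator


async def stream_with_error_chunk(
        results: AsyncGenerator[str, None]) -> AsyncGenerator[str, None]:
    """Wrap a streaming generator so a failure becomes an error chunk.

    The exception cannot reach the FastAPI exception handlers because
    the response has already started, so the best we can do is emit an
    error payload as the final chunk. A background watchdog (like
    watchdog_loop above) later notices engine.errored and triggers
    server shutdown.
    """
    try:
        async for chunk in results:
            yield chunk
    except Exception as e:
        yield f"data: {json.dumps({'error': {'message': str(e)}})}\n\n"
        yield "data: [DONE]\n\n"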