fix multiple issues

This commit is contained in:
2026-06-26 17:23:55 +08:00
parent 810874ddb8
commit f89bc60d59
2 changed files with 92 additions and 11 deletions

View File

@@ -294,12 +294,12 @@ class OpenAIServingChat(OpenAIServing):
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation, tokenizer,
request_metadata)
request_metadata, raw_request=raw_request)
try:
return await self.chat_completion_full_generator(
request, result_generator, request_id, conversation, tokenizer,
request_metadata)
request_metadata, raw_request=raw_request)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@@ -317,6 +317,7 @@ class OpenAIServingChat(OpenAIServing):
conversation: List[ConversationMessage],
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
raw_request: Optional[Request] = None,
) -> AsyncGenerator[str, None]:
model_name = self.base_model_paths[0].name
created_time = int(time.time())
@@ -390,6 +391,27 @@ class OpenAIServingChat(OpenAIServing):
yield "data: [DONE]\n\n"
return
# Background task: poll is_disconnected() every 300 ms and abort the
# engine request as soon as the client goes away. This catches the
# case where the HTTP layer (Starlette/uvicorn) does not actively read
# the receive channel during streaming, so is_disconnected() in
# iterate_with_cancellation never fires during fast decode.
_disconnect_watcher: Optional[asyncio.Task] = None
if raw_request is not None:
async def _watch_disconnect() -> None:
try:
while True:
if await raw_request.is_disconnected():
logger.info(
"Client disconnected (decode watcher), "
"aborting request %s", request_id)
await self.engine_client.abort(request_id)
return
await asyncio.sleep(0.3)
except asyncio.CancelledError:
pass
_disconnect_watcher = asyncio.ensure_future(_watch_disconnect())
try:
async for res in result_generator:
if res.prompt_token_ids is not None:
@@ -732,7 +754,7 @@ class OpenAIServingChat(OpenAIServing):
reasoning_tokens=total_reasoning)
except asyncio.CancelledError:
# Client disconnected; abort the engine request so GPU is freed.
# Client disconnected via CancelledError path; abort engine request.
await self.engine_client.abort(request_id)
return
except ValueError as e:
@@ -740,6 +762,18 @@ class OpenAIServingChat(OpenAIServing):
logger.error("error in chat completion stream generator: %s", e)
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
finally:
# Stop the disconnect watcher (it may already be done if it fired).
if _disconnect_watcher is not None and not _disconnect_watcher.done():
_disconnect_watcher.cancel()
try:
await _disconnect_watcher
except asyncio.CancelledError:
pass
# Covers GeneratorExit when Starlette calls aclose() on disconnect
# during decode (tokens arrive fast so CancelledError path is not
# always triggered). abort() is a no-op for already-finished requests.
await self.engine_client.abort(request_id)
# Send the final done message after all response.n are finished
yield "data: [DONE]\n\n"
@@ -751,18 +785,47 @@ class OpenAIServingChat(OpenAIServing):
conversation: List[ConversationMessage],
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
raw_request: Optional[Request] = None,
) -> Union[ErrorResponse, ChatCompletionResponse]:
model_name = self.base_model_paths[0].name
created_time = int(time.time())
final_res: Optional[RequestOutput] = None
# Background watcher: same logic as the streaming path — polls
# is_disconnected() every 300 ms so that a client disconnect during
# non-streaming decode is caught even when uvicorn isn't actively
# reading the receive channel.
_disconnect_watcher: Optional[asyncio.Task] = None
if raw_request is not None:
async def _watch_disconnect() -> None:
try:
while True:
if await raw_request.is_disconnected():
logger.info(
"Client disconnected (non-stream watcher), "
"aborting request %s", request_id)
await self.engine_client.abort(request_id)
return
await asyncio.sleep(0.3)
except asyncio.CancelledError:
pass
_disconnect_watcher = asyncio.ensure_future(_watch_disconnect())
try:
async for res in result_generator:
final_res = res
except asyncio.CancelledError:
await self.engine_client.abort(request_id)
return self.create_error_response("Client disconnected")
finally:
if _disconnect_watcher is not None and not _disconnect_watcher.done():
_disconnect_watcher.cancel()
try:
await _disconnect_watcher
except asyncio.CancelledError:
pass
await self.engine_client.abort(request_id)
assert final_res is not None