fix multiple issues
This commit is contained in:
@@ -294,12 +294,12 @@ class OpenAIServingChat(OpenAIServing):
|
||||
if request.stream:
|
||||
return self.chat_completion_stream_generator(
|
||||
request, result_generator, request_id, conversation, tokenizer,
|
||||
request_metadata)
|
||||
request_metadata, raw_request=raw_request)
|
||||
|
||||
try:
|
||||
return await self.chat_completion_full_generator(
|
||||
request, result_generator, request_id, conversation, tokenizer,
|
||||
request_metadata)
|
||||
request_metadata, raw_request=raw_request)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
@@ -317,6 +317,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
conversation: List[ConversationMessage],
|
||||
tokenizer: AnyTokenizer,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
model_name = self.base_model_paths[0].name
|
||||
created_time = int(time.time())
|
||||
@@ -390,6 +391,27 @@ class OpenAIServingChat(OpenAIServing):
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
# Background task: poll is_disconnected() every 300 ms and abort the
|
||||
# engine request as soon as the client goes away. This catches the
|
||||
# case where the HTTP layer (Starlette/uvicorn) does not actively read
|
||||
# the receive channel during streaming, so is_disconnected() in
|
||||
# iterate_with_cancellation never fires during fast decode.
|
||||
_disconnect_watcher: Optional[asyncio.Task] = None
|
||||
if raw_request is not None:
|
||||
async def _watch_disconnect() -> None:
|
||||
try:
|
||||
while True:
|
||||
if await raw_request.is_disconnected():
|
||||
logger.info(
|
||||
"Client disconnected (decode watcher), "
|
||||
"aborting request %s", request_id)
|
||||
await self.engine_client.abort(request_id)
|
||||
return
|
||||
await asyncio.sleep(0.3)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
_disconnect_watcher = asyncio.ensure_future(_watch_disconnect())
|
||||
|
||||
try:
|
||||
async for res in result_generator:
|
||||
if res.prompt_token_ids is not None:
|
||||
@@ -732,7 +754,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
reasoning_tokens=total_reasoning)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Client disconnected; abort the engine request so GPU is freed.
|
||||
# Client disconnected via CancelledError path; abort engine request.
|
||||
await self.engine_client.abort(request_id)
|
||||
return
|
||||
except ValueError as e:
|
||||
@@ -740,6 +762,18 @@ class OpenAIServingChat(OpenAIServing):
|
||||
logger.error("error in chat completion stream generator: %s", e)
|
||||
data = self.create_streaming_error_response(str(e))
|
||||
yield f"data: {data}\n\n"
|
||||
finally:
|
||||
# Stop the disconnect watcher (it may already be done if it fired).
|
||||
if _disconnect_watcher is not None and not _disconnect_watcher.done():
|
||||
_disconnect_watcher.cancel()
|
||||
try:
|
||||
await _disconnect_watcher
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
# Covers GeneratorExit when Starlette calls aclose() on disconnect
|
||||
# during decode (tokens arrive fast so CancelledError path is not
|
||||
# always triggered). abort() is a no-op for already-finished requests.
|
||||
await self.engine_client.abort(request_id)
|
||||
# Send the final done message after all response.n are finished
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
@@ -751,18 +785,47 @@ class OpenAIServingChat(OpenAIServing):
|
||||
conversation: List[ConversationMessage],
|
||||
tokenizer: AnyTokenizer,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[ErrorResponse, ChatCompletionResponse]:
|
||||
|
||||
model_name = self.base_model_paths[0].name
|
||||
created_time = int(time.time())
|
||||
final_res: Optional[RequestOutput] = None
|
||||
|
||||
# Background watcher: same logic as the streaming path — polls
|
||||
# is_disconnected() every 300 ms so that a client disconnect during
|
||||
# non-streaming decode is caught even when uvicorn isn't actively
|
||||
# reading the receive channel.
|
||||
_disconnect_watcher: Optional[asyncio.Task] = None
|
||||
if raw_request is not None:
|
||||
async def _watch_disconnect() -> None:
|
||||
try:
|
||||
while True:
|
||||
if await raw_request.is_disconnected():
|
||||
logger.info(
|
||||
"Client disconnected (non-stream watcher), "
|
||||
"aborting request %s", request_id)
|
||||
await self.engine_client.abort(request_id)
|
||||
return
|
||||
await asyncio.sleep(0.3)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
_disconnect_watcher = asyncio.ensure_future(_watch_disconnect())
|
||||
|
||||
try:
|
||||
async for res in result_generator:
|
||||
final_res = res
|
||||
except asyncio.CancelledError:
|
||||
await self.engine_client.abort(request_id)
|
||||
return self.create_error_response("Client disconnected")
|
||||
finally:
|
||||
if _disconnect_watcher is not None and not _disconnect_watcher.done():
|
||||
_disconnect_watcher.cancel()
|
||||
try:
|
||||
await _disconnect_watcher
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
await self.engine_client.abort(request_id)
|
||||
|
||||
assert final_res is not None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user