fix multiple issues
This commit is contained in:
@@ -320,12 +320,20 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
|||||||
prompt_logprobs = self.top_logprobs
|
prompt_logprobs = self.top_logprobs
|
||||||
|
|
||||||
guided_json_object = None
|
guided_json_object = None
|
||||||
if (self.response_format is not None
|
guided_json_from_schema = None
|
||||||
and self.response_format.type == "json_object"):
|
if self.response_format is not None:
|
||||||
guided_json_object = True
|
if self.response_format.type == "json_object":
|
||||||
|
guided_json_object = True
|
||||||
|
elif (self.response_format.type == "json_schema"
|
||||||
|
and self.response_format.json_schema is not None
|
||||||
|
and self.response_format.json_schema.json_schema is not None):
|
||||||
|
guided_json_from_schema = \
|
||||||
|
self.response_format.json_schema.json_schema
|
||||||
|
|
||||||
guided_decoding = GuidedDecodingParams.from_optional(
|
guided_decoding = GuidedDecodingParams.from_optional(
|
||||||
json=self._get_guided_json_from_tool() or self.guided_json,
|
json=(self._get_guided_json_from_tool()
|
||||||
|
or self.guided_json
|
||||||
|
or guided_json_from_schema),
|
||||||
regex=self.guided_regex,
|
regex=self.guided_regex,
|
||||||
choice=self.guided_choice,
|
choice=self.guided_choice,
|
||||||
grammar=self.guided_grammar,
|
grammar=self.guided_grammar,
|
||||||
@@ -398,6 +406,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
|||||||
normalized.append(msg)
|
normalized.append(msg)
|
||||||
continue
|
continue
|
||||||
if msg.get("content") is None:
|
if msg.get("content") is None:
|
||||||
|
if msg.get("reasoning_content") is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Each message must have at least one of 'content' or "
|
||||||
|
"'reasoning_content'.")
|
||||||
msg = {**msg, "content": ""}
|
msg = {**msg, "content": ""}
|
||||||
normalized.append(msg)
|
normalized.append(msg)
|
||||||
data = {**data, "messages": normalized}
|
data = {**data, "messages": normalized}
|
||||||
@@ -639,12 +651,18 @@ class CompletionRequest(OpenAIBaseModel):
|
|||||||
echo_without_generation = self.echo and self.max_tokens == 0
|
echo_without_generation = self.echo and self.max_tokens == 0
|
||||||
|
|
||||||
guided_json_object = None
|
guided_json_object = None
|
||||||
if (self.response_format is not None
|
guided_json_from_schema = None
|
||||||
and self.response_format.type == "json_object"):
|
if self.response_format is not None:
|
||||||
guided_json_object = True
|
if self.response_format.type == "json_object":
|
||||||
|
guided_json_object = True
|
||||||
|
elif (self.response_format.type == "json_schema"
|
||||||
|
and self.response_format.json_schema is not None
|
||||||
|
and self.response_format.json_schema.json_schema is not None):
|
||||||
|
guided_json_from_schema = \
|
||||||
|
self.response_format.json_schema.json_schema
|
||||||
|
|
||||||
guided_decoding = GuidedDecodingParams.from_optional(
|
guided_decoding = GuidedDecodingParams.from_optional(
|
||||||
json=self.guided_json,
|
json=self.guided_json or guided_json_from_schema,
|
||||||
regex=self.guided_regex,
|
regex=self.guided_regex,
|
||||||
choice=self.guided_choice,
|
choice=self.guided_choice,
|
||||||
grammar=self.guided_grammar,
|
grammar=self.guided_grammar,
|
||||||
|
|||||||
@@ -294,12 +294,12 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
if request.stream:
|
if request.stream:
|
||||||
return self.chat_completion_stream_generator(
|
return self.chat_completion_stream_generator(
|
||||||
request, result_generator, request_id, conversation, tokenizer,
|
request, result_generator, request_id, conversation, tokenizer,
|
||||||
request_metadata)
|
request_metadata, raw_request=raw_request)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await self.chat_completion_full_generator(
|
return await self.chat_completion_full_generator(
|
||||||
request, result_generator, request_id, conversation, tokenizer,
|
request, result_generator, request_id, conversation, tokenizer,
|
||||||
request_metadata)
|
request_metadata, raw_request=raw_request)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
# TODO: Use a vllm-specific Validation Error
|
# TODO: Use a vllm-specific Validation Error
|
||||||
return self.create_error_response(str(e))
|
return self.create_error_response(str(e))
|
||||||
@@ -317,6 +317,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
conversation: List[ConversationMessage],
|
conversation: List[ConversationMessage],
|
||||||
tokenizer: AnyTokenizer,
|
tokenizer: AnyTokenizer,
|
||||||
request_metadata: RequestResponseMetadata,
|
request_metadata: RequestResponseMetadata,
|
||||||
|
raw_request: Optional[Request] = None,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
model_name = self.base_model_paths[0].name
|
model_name = self.base_model_paths[0].name
|
||||||
created_time = int(time.time())
|
created_time = int(time.time())
|
||||||
@@ -390,6 +391,27 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Background task: poll is_disconnected() every 300 ms and abort the
|
||||||
|
# engine request as soon as the client goes away. This catches the
|
||||||
|
# case where the HTTP layer (Starlette/uvicorn) does not actively read
|
||||||
|
# the receive channel during streaming, so is_disconnected() in
|
||||||
|
# iterate_with_cancellation never fires during fast decode.
|
||||||
|
_disconnect_watcher: Optional[asyncio.Task] = None
|
||||||
|
if raw_request is not None:
|
||||||
|
async def _watch_disconnect() -> None:
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
if await raw_request.is_disconnected():
|
||||||
|
logger.info(
|
||||||
|
"Client disconnected (decode watcher), "
|
||||||
|
"aborting request %s", request_id)
|
||||||
|
await self.engine_client.abort(request_id)
|
||||||
|
return
|
||||||
|
await asyncio.sleep(0.3)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
_disconnect_watcher = asyncio.ensure_future(_watch_disconnect())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for res in result_generator:
|
async for res in result_generator:
|
||||||
if res.prompt_token_ids is not None:
|
if res.prompt_token_ids is not None:
|
||||||
@@ -732,7 +754,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
reasoning_tokens=total_reasoning)
|
reasoning_tokens=total_reasoning)
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
# Client disconnected; abort the engine request so GPU is freed.
|
# Client disconnected via CancelledError path; abort engine request.
|
||||||
await self.engine_client.abort(request_id)
|
await self.engine_client.abort(request_id)
|
||||||
return
|
return
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@@ -740,6 +762,18 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
logger.error("error in chat completion stream generator: %s", e)
|
logger.error("error in chat completion stream generator: %s", e)
|
||||||
data = self.create_streaming_error_response(str(e))
|
data = self.create_streaming_error_response(str(e))
|
||||||
yield f"data: {data}\n\n"
|
yield f"data: {data}\n\n"
|
||||||
|
finally:
|
||||||
|
# Stop the disconnect watcher (it may already be done if it fired).
|
||||||
|
if _disconnect_watcher is not None and not _disconnect_watcher.done():
|
||||||
|
_disconnect_watcher.cancel()
|
||||||
|
try:
|
||||||
|
await _disconnect_watcher
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
# Covers GeneratorExit when Starlette calls aclose() on disconnect
|
||||||
|
# during decode (tokens arrive fast so CancelledError path is not
|
||||||
|
# always triggered). abort() is a no-op for already-finished requests.
|
||||||
|
await self.engine_client.abort(request_id)
|
||||||
# Send the final done message after all response.n are finished
|
# Send the final done message after all response.n are finished
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
@@ -751,18 +785,47 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
conversation: List[ConversationMessage],
|
conversation: List[ConversationMessage],
|
||||||
tokenizer: AnyTokenizer,
|
tokenizer: AnyTokenizer,
|
||||||
request_metadata: RequestResponseMetadata,
|
request_metadata: RequestResponseMetadata,
|
||||||
|
raw_request: Optional[Request] = None,
|
||||||
) -> Union[ErrorResponse, ChatCompletionResponse]:
|
) -> Union[ErrorResponse, ChatCompletionResponse]:
|
||||||
|
|
||||||
model_name = self.base_model_paths[0].name
|
model_name = self.base_model_paths[0].name
|
||||||
created_time = int(time.time())
|
created_time = int(time.time())
|
||||||
final_res: Optional[RequestOutput] = None
|
final_res: Optional[RequestOutput] = None
|
||||||
|
|
||||||
|
# Background watcher: same logic as the streaming path — polls
|
||||||
|
# is_disconnected() every 300 ms so that a client disconnect during
|
||||||
|
# non-streaming decode is caught even when uvicorn isn't actively
|
||||||
|
# reading the receive channel.
|
||||||
|
_disconnect_watcher: Optional[asyncio.Task] = None
|
||||||
|
if raw_request is not None:
|
||||||
|
async def _watch_disconnect() -> None:
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
if await raw_request.is_disconnected():
|
||||||
|
logger.info(
|
||||||
|
"Client disconnected (non-stream watcher), "
|
||||||
|
"aborting request %s", request_id)
|
||||||
|
await self.engine_client.abort(request_id)
|
||||||
|
return
|
||||||
|
await asyncio.sleep(0.3)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
_disconnect_watcher = asyncio.ensure_future(_watch_disconnect())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for res in result_generator:
|
async for res in result_generator:
|
||||||
final_res = res
|
final_res = res
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
await self.engine_client.abort(request_id)
|
await self.engine_client.abort(request_id)
|
||||||
return self.create_error_response("Client disconnected")
|
return self.create_error_response("Client disconnected")
|
||||||
|
finally:
|
||||||
|
if _disconnect_watcher is not None and not _disconnect_watcher.done():
|
||||||
|
_disconnect_watcher.cancel()
|
||||||
|
try:
|
||||||
|
await _disconnect_watcher
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
await self.engine_client.abort(request_id)
|
||||||
|
|
||||||
assert final_res is not None
|
assert final_res is not None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user