diff --git a/qwen3_6_scripts/protocol.py b/qwen3_6_scripts/protocol.py index 456ad51..1efccdd 100644 --- a/qwen3_6_scripts/protocol.py +++ b/qwen3_6_scripts/protocol.py @@ -320,12 +320,20 @@ class ChatCompletionRequest(OpenAIBaseModel): prompt_logprobs = self.top_logprobs guided_json_object = None - if (self.response_format is not None - and self.response_format.type == "json_object"): - guided_json_object = True + guided_json_from_schema = None + if self.response_format is not None: + if self.response_format.type == "json_object": + guided_json_object = True + elif (self.response_format.type == "json_schema" + and self.response_format.json_schema is not None + and self.response_format.json_schema.json_schema is not None): + guided_json_from_schema = \ + self.response_format.json_schema.json_schema guided_decoding = GuidedDecodingParams.from_optional( - json=self._get_guided_json_from_tool() or self.guided_json, + json=(self._get_guided_json_from_tool() + or self.guided_json + or guided_json_from_schema), regex=self.guided_regex, choice=self.guided_choice, grammar=self.guided_grammar, @@ -398,6 +406,10 @@ class ChatCompletionRequest(OpenAIBaseModel): normalized.append(msg) continue if msg.get("content") is None: + if msg.get("reasoning_content") is None: + raise ValueError( + "Each message must have at least one of 'content' or " + "'reasoning_content'.") msg = {**msg, "content": ""} normalized.append(msg) data = {**data, "messages": normalized} @@ -639,12 +651,18 @@ class CompletionRequest(OpenAIBaseModel): echo_without_generation = self.echo and self.max_tokens == 0 guided_json_object = None - if (self.response_format is not None - and self.response_format.type == "json_object"): - guided_json_object = True + guided_json_from_schema = None + if self.response_format is not None: + if self.response_format.type == "json_object": + guided_json_object = True + elif (self.response_format.type == "json_schema" + and self.response_format.json_schema is not None + and self.response_format.json_schema.json_schema is not None): + guided_json_from_schema = \ + self.response_format.json_schema.json_schema guided_decoding = GuidedDecodingParams.from_optional( - json=self.guided_json, + json=self.guided_json or guided_json_from_schema, regex=self.guided_regex, choice=self.guided_choice, grammar=self.guided_grammar, diff --git a/qwen3_6_scripts/serving_chat.py b/qwen3_6_scripts/serving_chat.py index c4d24fa..988905d 100644 --- a/qwen3_6_scripts/serving_chat.py +++ b/qwen3_6_scripts/serving_chat.py @@ -294,12 +294,12 @@ class OpenAIServingChat(OpenAIServing): if request.stream: return self.chat_completion_stream_generator( request, result_generator, request_id, conversation, tokenizer, - request_metadata) + request_metadata, raw_request=raw_request) try: return await self.chat_completion_full_generator( request, result_generator, request_id, conversation, tokenizer, - request_metadata) + request_metadata, raw_request=raw_request) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -317,6 +317,7 @@ class OpenAIServingChat(OpenAIServing): conversation: List[ConversationMessage], tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, + raw_request: Optional[Request] = None, ) -> AsyncGenerator[str, None]: model_name = self.base_model_paths[0].name created_time = int(time.time()) @@ -390,6 +391,27 @@ class OpenAIServingChat(OpenAIServing): yield "data: [DONE]\n\n" return + # Background task: poll is_disconnected() every 300 ms and abort the + # engine request as soon as the client goes away. This catches the + # case where the HTTP layer (Starlette/uvicorn) does not actively read + # the receive channel during streaming, so is_disconnected() in + # iterate_with_cancellation never fires during fast decode. + _disconnect_watcher: Optional[asyncio.Task] = None + if raw_request is not None: + async def _watch_disconnect() -> None: + try: + while True: + if await raw_request.is_disconnected(): + logger.info( + "Client disconnected (decode watcher), " + "aborting request %s", request_id) + await self.engine_client.abort(request_id) + return + await asyncio.sleep(0.3) + except asyncio.CancelledError: + pass + _disconnect_watcher = asyncio.ensure_future(_watch_disconnect()) + try: async for res in result_generator: if res.prompt_token_ids is not None: @@ -732,7 +754,7 @@ class OpenAIServingChat(OpenAIServing): reasoning_tokens=total_reasoning) except asyncio.CancelledError: - # Client disconnected; abort the engine request so GPU is freed. + # Client disconnected via CancelledError path; abort engine request. await self.engine_client.abort(request_id) return except ValueError as e: @@ -740,6 +762,18 @@ class OpenAIServingChat(OpenAIServing): logger.error("error in chat completion stream generator: %s", e) data = self.create_streaming_error_response(str(e)) yield f"data: {data}\n\n" + finally: + # Stop the disconnect watcher (it may already be done if it fired). + if _disconnect_watcher is not None and not _disconnect_watcher.done(): + _disconnect_watcher.cancel() + try: + await _disconnect_watcher + except asyncio.CancelledError: + pass + # Covers GeneratorExit when Starlette calls aclose() on disconnect + # during decode (tokens arrive fast so CancelledError path is not + # always triggered). abort() is a no-op for already-finished requests. + await self.engine_client.abort(request_id) # Send the final done message after all response.n are finished yield "data: [DONE]\n\n" @@ -751,18 +785,47 @@ class OpenAIServingChat(OpenAIServing): conversation: List[ConversationMessage], tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, + raw_request: Optional[Request] = None, ) -> Union[ErrorResponse, ChatCompletionResponse]: model_name = self.base_model_paths[0].name created_time = int(time.time()) final_res: Optional[RequestOutput] = None + # Background watcher: same logic as the streaming path — polls + # is_disconnected() every 300 ms so that a client disconnect during + # non-streaming decode is caught even when uvicorn isn't actively + # reading the receive channel. + _disconnect_watcher: Optional[asyncio.Task] = None + if raw_request is not None: + async def _watch_disconnect() -> None: + try: + while True: + if await raw_request.is_disconnected(): + logger.info( + "Client disconnected (non-stream watcher), " + "aborting request %s", request_id) + await self.engine_client.abort(request_id) + return + await asyncio.sleep(0.3) + except asyncio.CancelledError: + pass + _disconnect_watcher = asyncio.ensure_future(_watch_disconnect()) + try: async for res in result_generator: final_res = res except asyncio.CancelledError: await self.engine_client.abort(request_id) return self.create_error_response("Client disconnected") + finally: + if _disconnect_watcher is not None and not _disconnect_watcher.done(): + _disconnect_watcher.cancel() + try: + await _disconnect_watcher + except asyncio.CancelledError: + pass + await self.engine_client.abort(request_id) assert final_res is not None