From 1afe3d07987aaae92c9c050092ad64b97b7a0435 Mon Sep 17 00:00:00 2001
From: Xihuai Wang
Date: Thu, 27 Mar 2025 15:16:52 +0800
Subject: [PATCH] Align finish reason and stream mode in openai api (#4388)

---
 python/sglang/srt/openai_api/adapter.py  | 148 +++++++++++------------
 python/sglang/srt/openai_api/protocol.py |  12 +-
 test/srt/test_openai_server.py           |   7 +-
 3 files changed, 87 insertions(+), 80 deletions(-)

diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index 09e7bf3b4..948541aee 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -645,7 +645,7 @@ def v1_generate_response(
                 "index": 0,
                 "text": text,
                 "logprobs": logprobs,
-                "finish_reason": (finish_reason["type"] if finish_reason else ""),
+                "finish_reason": finish_reason["type"] if finish_reason else None,
                 "matched_stop": (
                     finish_reason["matched"]
                     if finish_reason and "matched" in finish_reason
@@ -657,7 +657,7 @@ def v1_generate_response(
                 index=idx,
                 text=text,
                 logprobs=logprobs,
-                finish_reason=(finish_reason["type"] if finish_reason else ""),
+                finish_reason=finish_reason["type"] if finish_reason else None,
                 matched_stop=(
                     finish_reason["matched"]
                     if finish_reason and "matched" in finish_reason
@@ -805,7 +805,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                         index=index,
                         text=delta,
                         logprobs=logprobs,
-                        finish_reason=(finish_reason["type"] if finish_reason else ""),
+                        finish_reason=finish_reason["type"] if finish_reason else None,
                         matched_stop=(
                             finish_reason["matched"]
                             if finish_reason and "matched" in finish_reason
@@ -1216,7 +1216,7 @@ def v1_chat_generate_response(
                     "reasoning_content": reasoning_text if reasoning_text else None,
                 },
                 "logprobs": choice_logprobs.model_dump() if choice_logprobs else None,
-                "finish_reason": (finish_reason["type"] if finish_reason else ""),
+                "finish_reason": finish_reason["type"] if finish_reason else None,
                 "matched_stop": (
                     finish_reason["matched"]
                     if finish_reason and "matched" in finish_reason
@@ -1233,7 +1233,7 @@ def v1_chat_generate_response(
                     reasoning_content=reasoning_text if reasoning_text else None,
                 ),
                 logprobs=choice_logprobs,
-                finish_reason=(finish_reason["type"] if finish_reason else ""),
+                finish_reason=finish_reason["type"] if finish_reason else None,
                 matched_stop=(
                     finish_reason["matched"]
                     if finish_reason and "matched" in finish_reason
@@ -1377,23 +1377,11 @@ async def v1_chat_completions(
                     if is_first:
                         # First chunk with role
                         is_first = False
-                        if (
-                            tokenizer_manager.server_args.reasoning_parser
-                            and request.separate_reasoning
-                        ):
-                            delta = DeltaMessage(
-                                role="assistant", reasoning_content=None
-                            )
-                        else:
-                            delta = DeltaMessage(role="assistant", content=None)
+                        delta = DeltaMessage(role="assistant")
                         choice_data = ChatCompletionResponseStreamChoice(
                             index=index,
                             delta=delta,
-                            finish_reason=(
-                                None
-                                if finish_reason_type and len(finish_reason_type) == 0
-                                else finish_reason_type
-                            ),
+                            finish_reason=finish_reason_type,
                             matched_stop=(
                                 finish_reason["matched"]
                                 if finish_reason and "matched" in finish_reason
@@ -1434,12 +1422,7 @@ async def v1_chat_completions(
                                     reasoning_text if reasoning_text else None
                                 )
                             ),
-                            finish_reason=(
-                                None
-                                if finish_reason_type
-                                and len(finish_reason_type) == 0
-                                else finish_reason_type
-                            ),
+                            finish_reason=finish_reason_type,
                         )
                         chunk = ChatCompletionStreamResponse(
                             id=content["meta_info"]["id"],
@@ -1471,12 +1454,7 @@ async def v1_chat_completions(
                             delta=DeltaMessage(
                                 content=normal_text if normal_text else None
                             ),
-                            finish_reason=(
-                                None
-                                if finish_reason_type
-                                and len(finish_reason_type) == 0
-                                else finish_reason_type
-                            ),
+                            finish_reason=finish_reason_type,
                         )
                         chunk = ChatCompletionStreamResponse(
                             id=content["meta_info"]["id"],
@@ -1490,11 +1468,7 @@ async def v1_chat_completions(
 
                         for call_item in calls:
                             # transform call_item -> FunctionResponse + ToolCall
-                            if (
-                                content["meta_info"]["finish_reason"]
-                                and content["meta_info"]["finish_reason"]["type"]
-                                == "stop"
-                            ):
+                            if finish_reason_type == "stop":
                                 latest_delta_len = 0
                                 if isinstance(call_item.parameters, str):
                                     latest_delta_len = len(call_item.parameters)
@@ -1515,6 +1489,8 @@ async def v1_chat_completions(
                                 )
                                 call_item.parameters = remaining_call
 
+                            finish_reason_type = "tool_calls"
+
                             tool_call = ToolCall(
                                 id=str(call_item.tool_index),
                                 function=FunctionResponse(
@@ -1524,10 +1500,13 @@ async def v1_chat_completions(
                             )
                             choice_data = ChatCompletionResponseStreamChoice(
                                 index=index,
-                                delta=DeltaMessage(
-                                    role="assistant", tool_calls=[tool_call]
-                                ),
-                                finish_reason="tool_call",
+                                delta=DeltaMessage(tool_calls=[tool_call]),
+                                finish_reason=(
+                                    None
+                                    if request.stream_options
+                                    and request.stream_options.include_usage
+                                    else finish_reason_type
+                                ),  # additional chunk will be return
                             )
                             chunk = ChatCompletionStreamResponse(
                                 id=content["meta_info"]["id"],
@@ -1542,30 +1521,44 @@ async def v1_chat_completions(
                     else:
                         # No tool calls => just treat this as normal text
-                        choice_data = ChatCompletionResponseStreamChoice(
-                            index=index,
-                            delta=DeltaMessage(content=delta if delta else None),
-                            finish_reason=(
-                                None
-                                if finish_reason_type and len(finish_reason_type) == 0
-                                else finish_reason_type
-                            ),
-                            matched_stop=(
-                                finish_reason["matched"]
-                                if finish_reason and "matched" in finish_reason
-                                else None
-                            ),
-                            logprobs=choice_logprobs,
-                        )
-                        chunk = ChatCompletionStreamResponse(
-                            id=content["meta_info"]["id"],
-                            created=created,
-                            choices=[choice_data],
-                            model=request.model,
-                        )
-                        yield f"data: {chunk.model_dump_json()}\n\n"
-                        stream_buffers[index] = new_stream_buffer
-                        is_firsts[index] = is_first
+                        if delta or not (
+                            request.stream_options
+                            and request.stream_options.include_usage
+                        ):
+                            choice_data = ChatCompletionResponseStreamChoice(
+                                index=index,
+                                delta=DeltaMessage(content=delta if delta else None),
+                                finish_reason=(
+                                    None
+                                    if request.stream_options
+                                    and request.stream_options.include_usage
+                                    else finish_reason_type
+                                ),
+                                matched_stop=(
+                                    finish_reason["matched"]
+                                    if finish_reason and "matched" in finish_reason
+                                    else None
+                                ),
+                                logprobs=choice_logprobs,
+                            )
+                            chunk = ChatCompletionStreamResponse(
+                                id=content["meta_info"]["id"],
+                                created=created,
+                                choices=[choice_data],
+                                model=request.model,
+                            )
+                            yield f"data: {chunk.model_dump_json()}\n\n"
+                            stream_buffers[index] = new_stream_buffer
+                            is_firsts[index] = is_first
+
+                if finish_reason_type == "stop" and request.tool_choice != "none":
+                    parser = FunctionCallParser(
+                        tools=request.tools,
+                        tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
+                    )
+                    if parser.has_tool_call(new_stream_buffer):
+                        # if the stream ends with empty string after tool calls
+                        finish_reason_type = "tool_calls"
 
                 if request.stream_options and request.stream_options.include_usage:
                     total_prompt_tokens = sum(
                         tokens
@@ -1590,17 +1583,22 @@ async def v1_chat_completions(
                         prompt_tokens_details=prompt_tokens_details,
                     )
 
-                    final_usage_chunk = ChatCompletionStreamResponse(
-                        id=content["meta_info"]["id"],
-                        created=created,
-                        choices=[],
-                        model=request.model,
-                        usage=usage,
-                    )
-                    final_usage_data = final_usage_chunk.model_dump_json(
-                        exclude_none=True
-                    )
-                    yield f"data: {final_usage_data}\n\n"
+                else:
+                    usage = None
+                final_usage_chunk = ChatCompletionStreamResponse(
+                    id=content["meta_info"]["id"],
+                    created=created,
+                    choices=[
+                        ChatCompletionResponseStreamChoice(
+                            index=index,
+                            delta=DeltaMessage(),
+                            finish_reason=finish_reason_type,
+                        )
+                    ],
+                    model=request.model,
+                    usage=usage,
+                )
+                yield f"data: {final_usage_chunk.model_dump_json()}\n\n"
             except ValueError as e:
                 error = create_streaming_error_response(str(e))
                 yield f"data: {error}\n\n"
diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
index c4b89c870..60a01766d 100644
--- a/python/sglang/srt/openai_api/protocol.py
+++ b/python/sglang/srt/openai_api/protocol.py
@@ -187,7 +187,7 @@ class CompletionResponseChoice(BaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[str] = None
+    finish_reason: Literal["stop", "length", "content_filter"]
     matched_stop: Union[None, int, str] = None
 
 
@@ -204,7 +204,7 @@ class CompletionResponseStreamChoice(BaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[str] = None
+    finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
     matched_stop: Union[None, int, str] = None
 
 
@@ -387,7 +387,9 @@ class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
     logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
-    finish_reason: str
+    finish_reason: Literal[
+        "stop", "length", "tool_calls", "content_filter", "function_call"
+    ]
     matched_stop: Union[None, int, str] = None
 
 
@@ -411,7 +413,9 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
     logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
-    finish_reason: Optional[str] = None
+    finish_reason: Optional[
+        Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
+    ] = None
     matched_stop: Union[None, int, str] = None
 
 
diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py
index e18168eb7..af2ed20e9 100644
--- a/test/srt/test_openai_server.py
+++ b/test/srt/test_openai_server.py
@@ -258,7 +258,12 @@ class TestOpenAIServer(CustomTestCase):
                             ret_num_top_logprobs == logprobs
                         ), f"{ret_num_top_logprobs} vs {logprobs}"
 
-                assert isinstance(data.content, str) or response.choices[0].finish_reason
+                assert (
+                    isinstance(data.content, str)
+                    or isinstance(data.reasoning_content, str)
+                    or len(data.tool_calls) > 0
+                    or response.choices[0].finish_reason
+                )
         assert response.id
         assert response.created
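
Usage sketch (reviewer illustration, not part of the patch): a minimal
client-side check of the aligned behavior, assuming an SGLang server is
already running with its OpenAI-compatible API at http://localhost:30000/v1;
the base_url, api_key, and model name are placeholders. After this change,
streamed chunks carry finish_reason=None until generation ends, and with
stream_options.include_usage the terminal chunk carries both the usage block
and the final finish_reason instead of an empty string.

    # Illustration only; base_url, api_key, and model are placeholders.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

    stream = client.chat.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model name
        messages=[{"role": "user", "content": "Say hello."}],
        stream=True,
        stream_options={"include_usage": True},
    )

    finish_reason = None
    for chunk in stream:
        # Only the terminal chunk carries usage when include_usage is set.
        if chunk.usage is not None:
            print("\nusage:", chunk.usage)
        for choice in chunk.choices:
            if choice.delta.content:
                print(choice.delta.content, end="", flush=True)
            # Content chunks report finish_reason=None; the terminal chunk
            # reports a Literal such as "stop", "length", or "tool_calls".
            if choice.finish_reason is not None:
                finish_reason = choice.finish_reason

    # Per the protocol.py change, the value is now constrained to a Literal
    # and is never the empty string.
    assert finish_reason in (
        "stop", "length", "tool_calls", "content_filter", "function_call"
    )
    print("\nfinish_reason:", finish_reason)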