bugfix: Fix multiple finish_reason chunks and tool_calls finish reason check (#8417)

This commit is contained in:
Chang Su
2025-07-27 13:31:06 -07:00
committed by GitHub
parent e983d66680
commit b47eda3316
4 changed files with 500 additions and 235 deletions

View File

@@ -412,6 +412,8 @@ class OpenAIServingChat(OpenAIServingBase):
is_firsts = {}
stream_buffers = {}
n_prev_tokens = {}
has_tool_calls = {}
finish_reasons = {}
# Usage tracking
prompt_tokens = {}
@@ -443,6 +445,10 @@ class OpenAIServingChat(OpenAIServingBase):
finish_reason = content["meta_info"]["finish_reason"]
finish_reason_type = finish_reason["type"] if finish_reason else None
# Track finish_reason for each index
if finish_reason_type:
finish_reasons[index] = finish_reason
# First chunk with role
if is_firsts.get(index, True):
is_firsts[index] = False
@@ -450,13 +456,8 @@ class OpenAIServingChat(OpenAIServingBase):
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=delta,
finish_reason=finish_reason_type,
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
logprobs=choice_logprobs,
finish_reason=None,
logprobs=None,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
@@ -483,7 +484,7 @@ class OpenAIServingChat(OpenAIServingBase):
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(reasoning_content=reasoning_text),
finish_reason=finish_reason_type,
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
@@ -495,40 +496,34 @@ class OpenAIServingChat(OpenAIServingBase):
# Handle tool calls
if request.tool_choice != "none" and request.tools:
async for (
chunk,
tool_call_finish_reason_type,
) in self._process_tool_call_stream(
async for chunk in self._process_tool_call_stream(
index,
delta,
parser_dict,
content,
request,
finish_reason_type,
has_tool_calls,
):
if chunk:
yield chunk
finish_reason_type = tool_call_finish_reason_type
# Send any remaining tool call arguments when generation finishes
if finish_reason_type is not None and index in parser_dict:
parser = parser_dict[index]
remaining_chunk = self._check_for_unstreamed_tool_args(
parser, content, request, index
)
if remaining_chunk:
yield remaining_chunk
else:
# Regular content
if delta or not (
request.stream_options and request.stream_options.include_usage
):
if delta:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=delta if delta else None),
finish_reason=(
None
if request.stream_options
and request.stream_options.include_usage
else finish_reason_type
),
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
finish_reason=None,
matched_stop=None,
logprobs=choice_logprobs,
)
chunk = ChatCompletionStreamResponse(
@@ -539,26 +534,36 @@ class OpenAIServingChat(OpenAIServingBase):
)
yield f"data: {chunk.model_dump_json()}\n\n"
# Final chunk with finish_reason
finish_reason_chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[
ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(),
finish_reason=finish_reason_type,
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
)
],
model=request.model,
usage=None,
)
yield f"data: {finish_reason_chunk.model_dump_json()}\n\n"
# Send finish_reason chunks for each index that completed
for idx, finish_reason_data in finish_reasons.items():
finish_reason_type = finish_reason_data["type"]
# Change finish_reason to "tool_calls" if we had tool calls and stopped naturally
final_finish_reason = finish_reason_type
if has_tool_calls.get(idx, False) and finish_reason_type == "stop":
final_finish_reason = "tool_calls"
finish_reason_chunk = ChatCompletionStreamResponse(
id=content["meta_info"][
"id"
], # NOTE: openai uses the same chatcmpl-id for all indices
created=int(time.time()),
choices=[
ChatCompletionResponseStreamChoice(
index=idx,
delta=DeltaMessage(),
finish_reason=final_finish_reason,
matched_stop=(
finish_reason_data["matched"]
if "matched" in finish_reason_data
else None
),
)
],
model=request.model,
usage=None,
)
yield f"data: {finish_reason_chunk.model_dump_json()}\n\n"
# Send hidden states if requested
if request.return_hidden_states and hidden_states:
@@ -578,7 +583,7 @@ class OpenAIServingChat(OpenAIServingBase):
delta=DeltaMessage(
hidden_states=last_token_hidden_states
),
finish_reason=finish_reason_type,
finish_reason=None, # Hidden states don't need finish_reason
)
],
model=request.model,
@@ -857,7 +862,7 @@ class OpenAIServingChat(OpenAIServingBase):
parser_dict: Dict[int, FunctionCallParser],
content: Dict[str, Any],
request: ChatCompletionRequest,
finish_reason_type: Optional[str],
has_tool_calls: Dict[int, bool],
):
"""Process tool calls in streaming response"""
if index not in parser_dict:
@@ -874,7 +879,7 @@ class OpenAIServingChat(OpenAIServingBase):
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=normal_text),
finish_reason=finish_reason_type,
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
@@ -882,10 +887,13 @@ class OpenAIServingChat(OpenAIServingBase):
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n", finish_reason_type
yield f"data: {chunk.model_dump_json()}\n\n"
# Yield tool calls
for call_item in calls:
# Mark that this choice has tool calls
has_tool_calls[index] = True
# Tool call ID should be generated only once per tool call
if call_item.name:
# First chunk: include ID and function name
@@ -896,23 +904,6 @@ class OpenAIServingChat(OpenAIServingBase):
tool_call_id = None
function_name = None
if finish_reason_type == "stop":
# Handle remaining arguments
latest_delta_len = 0
if isinstance(call_item.parameters, str):
latest_delta_len = len(call_item.parameters)
expected_call = json.dumps(
parser.detector.prev_tool_call_arr[index].get("arguments", {}),
ensure_ascii=False,
)
actual_call = parser.detector.streamed_args_for_tool[index]
if latest_delta_len > 0:
actual_call = actual_call[:-latest_delta_len]
remaining_call = expected_call.replace(actual_call, "", 1)
call_item.parameters = remaining_call
finish_reason_type = "tool_calls"
tool_call = ToolCall(
id=tool_call_id,
index=call_item.tool_index,
@@ -925,11 +916,7 @@ class OpenAIServingChat(OpenAIServingBase):
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(tool_calls=[tool_call]),
finish_reason=(
None
if request.stream_options and request.stream_options.include_usage
else finish_reason_type
),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
@@ -937,7 +924,76 @@ class OpenAIServingChat(OpenAIServingBase):
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n", finish_reason_type
yield f"data: {chunk.model_dump_json()}\n\n"
if finish_reason_type == "stop":
yield None, "tool_calls"
def _check_for_unstreamed_tool_args(
    self,
    parser: FunctionCallParser,
    content: Dict[str, Any],
    request: ChatCompletionRequest,
    index: int,
) -> Optional[str]:
    """
    Check for any remaining tool call arguments that need to be streamed
    when generation finishes. This ensures tool calls are properly completed
    even if the model generates the final arguments in the last chunk.

    Args:
        parser: The per-index function-call parser; its ``detector`` tracks
            the tool calls parsed so far (``prev_tool_call_arr``) and the
            argument text already streamed (``streamed_args_for_tool``).
        content: Generation payload; only ``content["meta_info"]["id"]`` is
            read here, to reuse the same chat-completion id in the chunk.
        request: The originating request; only ``request.model`` is read.
        index: The choice index this chunk belongs to.

    Returns:
        A fully formatted SSE ``data: ...`` line carrying one argument-delta
        tool-call chunk (with ``finish_reason=None``), or ``None`` when there
        is nothing left to flush.
    """
    # Only check if we have tool calls and the parser has tracked data.
    # NOTE(review): attribute presence is detector-implementation dependent;
    # hasattr guards keep this a best-effort no-op for detectors that do not
    # track streaming state — confirm against the FunctionCallParser API.
    if (
        not hasattr(parser.detector, "prev_tool_call_arr")
        or not parser.detector.prev_tool_call_arr
    ):
        return None

    if (
        not hasattr(parser.detector, "streamed_args_for_tool")
        or not parser.detector.streamed_args_for_tool
    ):
        return None

    # Get the last tool call that was being processed.
    # NOTE(review): assumes only the final tool call can have unstreamed
    # arguments when generation stops — earlier calls are presumed fully
    # flushed; verify against the streaming parser's contract.
    tool_index = len(parser.detector.prev_tool_call_arr) - 1
    if tool_index < 0 or tool_index >= len(parser.detector.streamed_args_for_tool):
        return None

    # Get expected vs actual arguments: the full argument object the parser
    # accumulated, serialized the same way it was originally streamed,
    # versus the raw argument text already emitted to the client.
    expected_args = parser.detector.prev_tool_call_arr[tool_index].get(
        "arguments", {}
    )
    expected_call = json.dumps(expected_args, ensure_ascii=False)
    actual_call = parser.detector.streamed_args_for_tool[tool_index]

    # Check if there are remaining arguments to send.
    # NOTE(review): this relies on the streamed text being a leading
    # substring of the canonical json.dumps form; str.replace removes the
    # FIRST occurrence, which may mis-strip if actual_call recurs inside
    # expected_call (expected_call[len(actual_call):] after a startswith
    # check would be stricter) — confirm the detector guarantees prefix
    # alignment and identical serialization (key order, spacing).
    remaining_call = (
        expected_call.replace(actual_call, "", 1)
        if actual_call in expected_call
        else ""
    )

    if remaining_call:
        # Create tool call chunk with remaining arguments
        tool_call = ToolCall(
            id=None,  # No ID for argument deltas
            index=tool_index,
            function=FunctionResponse(
                name=None,  # No name for argument deltas
                arguments=remaining_call,
            ),
        )

        choice_data = ChatCompletionResponseStreamChoice(
            index=index,
            delta=DeltaMessage(tool_calls=[tool_call]),
            finish_reason=None,  # Don't send finish_reason with this chunk
        )

        chunk = ChatCompletionStreamResponse(
            id=content["meta_info"]["id"],
            created=int(time.time()),
            choices=[choice_data],
            model=request.model,
        )

        return f"data: {chunk.model_dump_json()}\n\n"

    return None