[dev] Add GLM-4.7 tool parser (#151)

Signed-off-by: zhangzhenyi <zhangzhenyi@baidu.com>
Co-authored-by: Li Wei <liwei.109@outlook.com>
This commit is contained in:
astrophel0
2026-01-30 15:24:14 +08:00
committed by root
parent e28b697458
commit 726cefb7a3
2 changed files with 1860 additions and 0 deletions

View File

@@ -0,0 +1,948 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import json
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import Callable, Final, Optional, Union
import jinja2
import partial_json_parser
import regex as re
from fastapi import Request
from openai_harmony import Message as OpenAIMessage
from pydantic import TypeAdapter
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
ConversationMessage,
random_tool_call_id)
from vllm.entrypoints.harmony_utils import (
get_developer_message, get_stop_tokens_for_assistant_actions,
get_streamable_parser_for_assistant, get_system_message, parse_chat_input,
parse_chat_output, render_for_completion)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProb, ChatCompletionLogProbs,
ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, FunctionDefinition,
PromptTokenUsageInfo, RequestResponseMetadata, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
clamp_prompt_logprobs)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolCall)
from vllm.entrypoints.utils import get_max_tokens
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls,
truncate_tool_call_ids,
validate_request_params)
from vllm.utils import as_list
logger = init_logger(__name__)
class OpenAIServingChat(OpenAIServing):
    """OpenAI-compatible ``/v1/chat/completions`` serving layer.

    Only the streaming and non-streaming response generators are visible in
    this excerpt; request validation and engine dispatch live elsewhere in
    the class.
    """

    async def chat_completion_stream_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        model_name: str,
        conversation: list[ConversationMessage],
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
        enable_force_include_usage: bool,
    ) -> AsyncGenerator[str, None]:
        """Stream a chat completion as Server-Sent Events.

        Yields ``data: <json>\\n\\n`` strings, one per delta chunk, followed
        by a final ``data: [DONE]\\n\\n``. Handles reasoning-parser output,
        named/"required"/"auto" tool choice, harmony (gpt-oss) parsing, echo,
        logprobs and usage reporting.  Errors are serialized into the stream
        rather than raised.
        """
        created_time = int(time.time())
        chunk_object_type: Final = "chat.completion.chunk"
        first_iteration = True
        # Send response for each token for each request.n (index)
        num_choices = 1 if request.n is None else request.n
        previous_num_tokens = [0] * num_choices
        finish_reason_sent = [False] * num_choices
        num_prompt_tokens = 0
        num_cached_tokens = None
        if self.use_harmony:
            # One streamable harmony parser per choice index.
            harmony_parsers = [
                get_streamable_parser_for_assistant()
                for _ in range(num_choices)
            ]
        if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_choice_function_name = request.tool_choice.function.name
        else:
            tool_choice_function_name = None
        # Determine whether tools are in use with "auto" tool choice
        tool_choice_auto = (
            not tool_choice_function_name
            and self._should_stream_with_auto_tool_parsing(request))
        all_previous_token_ids: Optional[list[list[int]]]
        function_name_returned = [False] * num_choices
        # Always track previous_texts for comprehensive output logging
        previous_texts = [""] * num_choices
        # Only one of these will be used, thus previous_texts and
        # all_previous_token_ids will not be used twice in the same iteration.
        if tool_choice_auto or self.reasoning_parser:
            # These are only required in "auto" tool choice case
            # NOTE(review): list-multiply shares one inner list across all
            # choices; entries are only ever reassigned (never mutated in
            # place) below, so this appears safe — confirm.
            all_previous_token_ids = [[]] * num_choices
            # For reasoning parser and tool call all enabled
            added_content_delta_arr = [False] * num_choices
            reasoning_end_arr = [False] * num_choices
        elif request.tool_choice == "required":
            all_previous_token_ids = None
        else:
            all_previous_token_ids = None
        # Whether the chat template requested reasoning ("thinking") output;
        # defaults to True when no kwargs are provided.
        enable_thinking: bool = request.chat_template_kwargs.get("enable_thinking", True) if request.chat_template_kwargs else True
        try:
            if self.reasoning_parser:
                reasoning_parser = self.reasoning_parser(tokenizer)
        except RuntimeError as e:
            logger.exception("Error in reasoning parser creation.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"
            return
        # Prepare the tool parser if it's needed
        try:
            if tool_choice_auto and self.tool_parser:
                # NOTE(review): list-multiply means all n choices share ONE
                # parser instance; its internal streaming state could
                # interleave across choices when n > 1 — confirm intended.
                tool_parsers: list[Optional[ToolParser]] = [
                    self.tool_parser(tokenizer)
                ] * num_choices
            else:
                tool_parsers = [None] * num_choices
        except Exception as e:
            logger.exception("Error in tool parser creation.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"
            return
        stream_options = request.stream_options
        if stream_options:
            include_usage = stream_options.include_usage \
                or enable_force_include_usage
            include_continuous_usage = include_usage and \
                stream_options.continuous_usage_stats
        else:
            include_usage, include_continuous_usage = False, False
        try:
            async for res in result_generator:
                if res.prompt_token_ids is not None:
                    num_prompt_tokens = len(res.prompt_token_ids)
                    if res.encoder_prompt_token_ids is not None:
                        num_prompt_tokens += len(res.encoder_prompt_token_ids)
                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
                    num_cached_tokens = res.num_cached_tokens
                    # Send first response for each request.n (index) with
                    # the role
                    role = self.get_chat_request_role(request)
                    # NOTE num_choices defaults to 1 so this usually executes
                    # once per request
                    for i in range(num_choices):
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(
                                role=role,
                                content="",
                            ),
                            logprobs=None,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        # if continuous usage stats are requested, add it
                        if include_continuous_usage:
                            chunk.usage = UsageInfo(
                                prompt_tokens=num_prompt_tokens,
                                completion_tokens=0,
                                total_tokens=num_prompt_tokens)
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"
                    # Send response to echo the input portion of the
                    # last message
                    if request.echo:
                        last_msg_content: Union[str, list[dict[str, str]]] = ""
                        if conversation and "content" in conversation[
                                -1] and conversation[-1].get("role") == role:
                            last_msg_content = conversation[-1]["content"] or ""
                        if last_msg_content:
                            for i in range(num_choices):
                                choice_data = (
                                    ChatCompletionResponseStreamChoice(
                                        index=i,
                                        delta=DeltaMessage(
                                            content=last_msg_content),
                                        logprobs=None,
                                        finish_reason=None))
                                chunk = ChatCompletionStreamResponse(
                                    id=request_id,
                                    object=chunk_object_type,
                                    created=created_time,
                                    choices=[choice_data],
                                    model=model_name)
                                if include_continuous_usage:
                                    chunk.usage = UsageInfo(
                                        prompt_tokens=num_prompt_tokens,
                                        completion_tokens=0,
                                        total_tokens=num_prompt_tokens)
                                data = chunk.model_dump_json(
                                    exclude_unset=True)
                                yield f"data: {data}\n\n"
                    first_iteration = False
                for output in res.outputs:
                    i = output.index
                    tool_parser = tool_parsers[i]
                    if finish_reason_sent[i]:
                        continue
                    if request.logprobs and request.top_logprobs is not None:
                        assert output.logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_chat_logprobs(
                            token_ids=output.token_ids,
                            top_logprobs=output.logprobs,
                            tokenizer=tokenizer,
                            num_output_top_logprobs=request.top_logprobs,
                            return_as_token_id=request.
                            return_tokens_as_token_ids,
                        )
                    else:
                        logprobs = None
                    if self.use_harmony:
                        harmony_parser = harmony_parsers[i]
                        for token_id in output.token_ids:
                            harmony_parser.process(token_id)
                        # FIXME(woosuk): Support function calling
                        is_final = harmony_parser.current_channel == "final"
                        if not (request.include_reasoning or is_final):
                            # Skip the reasoning content.
                            continue
                        delta_text = harmony_parser.last_content_delta or ""
                    else:
                        delta_text = output.text
                    if not delta_text and not output.token_ids and \
                            not previous_num_tokens[i]:
                        # Chunked prefill case, don't return empty chunks
                        continue
                    delta_message: Optional[DeltaMessage]
                    # just update previous_texts and previous_token_ids
                    if ((tool_choice_auto or self.reasoning_parser)
                            and not self.use_harmony):
                        assert previous_texts is not None
                        assert all_previous_token_ids is not None
                        previous_text = previous_texts[i]
                        previous_token_ids = all_previous_token_ids[i]
                        current_text = previous_text + delta_text
                        # avoid the None + list error.
                        if previous_token_ids:
                            current_token_ids = previous_token_ids + as_list(
                                output.token_ids)
                        else:
                            current_token_ids = as_list(output.token_ids)
                    if self.use_harmony:
                        if is_final:
                            delta_message = DeltaMessage(content=delta_text)
                        else:
                            delta_message = DeltaMessage(
                                reasoning_content=delta_text)
                    # handle streaming deltas for tools with named tool_choice
                    elif tool_choice_function_name:
                        if (self.reasoning_parser and not reasoning_end_arr[i]
                                and not reasoning_parser.is_reasoning_end(
                                    previous_token_ids)):
                            assert reasoning_parser is not None
                            delta_message = (
                                reasoning_parser.
                                extract_reasoning_content_streaming(
                                    previous_text,
                                    current_text,
                                    delta_text,
                                    previous_token_ids,
                                    current_token_ids,
                                    output.token_ids,
                                ))
                            # When encountering think end id in delta_token_ids
                            # or think end id in prompt_token_ids
                            # i.e {"enable_thinking": False},
                            # set reasoning status to end.
                            # Only keep 'content', remove 'reasoning_content'.
                            if reasoning_parser.is_reasoning_end(
                                    as_list(output.token_ids)) or (
                                        res.prompt_token_ids
                                        and reasoning_parser.is_reasoning_end(
                                            res.prompt_token_ids)):
                                reasoning_end_arr[i] = True
                                if delta_message and delta_message.content:
                                    # This need to be added to next `delta_text`
                                    current_text = delta_message.content
                                    delta_message.content = None
                                else:
                                    current_text = ""
                        else:
                            # Just to add remaining `content`
                            if self.reasoning_parser:
                                delta_text = previous_text + delta_text
                                current_text = ""
                            if function_name_returned[i]:
                                delta_tool_call = DeltaToolCall(
                                    function=DeltaFunctionCall(
                                        arguments=delta_text),
                                    index=i)
                            else:
                                delta_tool_call = DeltaToolCall(
                                    id=random_tool_call_id(),
                                    type="function",
                                    function=DeltaFunctionCall(
                                        name=tool_choice_function_name,
                                        arguments=delta_text),
                                    index=i)
                                function_name_returned[i] = True
                            delta_message = DeltaMessage(tool_calls=[
                                delta_tool_call,
                            ])
                    elif request.tool_choice == "required":
                        assert previous_texts is not None
                        previous_text = previous_texts[i]
                        current_text = previous_text + delta_text
                        fn_name_returned = function_name_returned[i]
                        if self.reasoning_parser:
                            _, content = \
                                reasoning_parser.extract_reasoning_content(
                                    current_text,
                                    request
                                )
                        else:
                            content = current_text
                        delta_message, function_name_returned[i] = (
                            self.extract_tool_call_required_streaming(
                                previous_text=previous_text,
                                current_text=content,
                                delta_text=delta_text,
                                function_name_returned=fn_name_returned))
                        # update the previous values for the next iteration
                        previous_texts[i] = current_text
                    # handle streaming deltas for tools with "auto" tool choice
                    # and reasoning parser
                    elif tool_choice_auto and self.reasoning_parser:
                        assert tool_parser is not None
                        assert reasoning_parser is not None
                        assert added_content_delta_arr is not None
                        assert reasoning_end_arr is not None
                        output_token_ids = as_list(output.token_ids)
                        if not reasoning_end_arr[i]:
                            delta_message = (
                                reasoning_parser.
                                extract_reasoning_content_streaming(
                                    previous_text,
                                    current_text,
                                    delta_text,
                                    previous_token_ids,
                                    current_token_ids,
                                    output_token_ids,
                                ))
                            # When encountering think end id in prompt_token_ids
                            # i.e {"enable_thinking": False},
                            # set reasoning status to end.
                            # Remove the text and token ids related
                            # to 'reasoning_content'.
                            if not enable_thinking:
                                reasoning_end_arr[i] = True
                                current_token_ids = output_token_ids
                                if delta_message and delta_message.reasoning_content:
                                    current_text = delta_message.reasoning_content
                                    delta_message.content = None
                                    delta_message.reasoning_content = None
                                else:
                                    # NOTE(review): raises AttributeError if
                                    # delta_message is None here — presumably
                                    # unreachable when thinking is disabled;
                                    # confirm.
                                    current_text = delta_message.content
                            # When encountering think end id in delta_token_ids,
                            # set reasoning status to end.
                            # Remove the text and token ids related
                            # to 'reasoning_content'.
                            if reasoning_parser.is_reasoning_end(
                                    output_token_ids):
                                reasoning_end_arr[i] = True
                                current_token_ids = \
                                    reasoning_parser.extract_content_ids(
                                        output_token_ids)
                                if delta_message and delta_message.content:
                                    current_text = delta_message.content
                                    delta_message.content = None
                                else:
                                    current_text = ""
                        # handle tool calls only after reasoning is done,
                        else:
                            delta_token_ids = output_token_ids
                            # First time to tool call,
                            # add the remaining text and token ids
                            # to delta from previous
                            if not added_content_delta_arr[i]:
                                added_content_delta_arr[i] = True
                                previous_text = ""
                                previous_token_ids = []
                                delta_text = current_text
                                delta_token_ids = current_token_ids
                            delta_message = (
                                tool_parser.extract_tool_calls_streaming(
                                    previous_text=previous_text,
                                    current_text=current_text,
                                    delta_text=delta_text,
                                    previous_token_ids=previous_token_ids,
                                    current_token_ids=current_token_ids,
                                    delta_token_ids=delta_token_ids,
                                    request=request))
                    # when only tool calls
                    elif tool_choice_auto:
                        assert tool_parser is not None
                        delta_message = (
                            tool_parser.extract_tool_calls_streaming(
                                previous_text=previous_text,
                                current_text=current_text,
                                delta_text=delta_text,
                                previous_token_ids=previous_token_ids,
                                current_token_ids=current_token_ids,
                                delta_token_ids=output.token_ids,
                                request=request))
                    # when only reasoning
                    elif self.reasoning_parser and enable_thinking:
                        delta_message = (reasoning_parser.
                                         extract_reasoning_content_streaming(
                                             previous_text,
                                             current_text,
                                             delta_text,
                                             previous_token_ids,
                                             current_token_ids,
                                             output.token_ids,
                                         ))
                    # handle streaming just a content delta
                    else:
                        delta_message = DeltaMessage(content=delta_text)
                    # update the previous values for the next iteration
                    if tool_choice_auto or self.reasoning_parser:
                        assert previous_texts is not None
                        assert all_previous_token_ids is not None
                        previous_texts[i] = current_text
                        all_previous_token_ids[i] = current_token_ids
                    else:
                        # Update for comprehensive logging even in simple case
                        assert previous_texts is not None
                        previous_texts[i] += delta_text
                    # set the previous values for the next iteration
                    previous_num_tokens[i] += len(output.token_ids)
                    # if the message delta is None (e.g. because it was a
                    # "control token" for tool calls or the parser otherwise
                    # wasn't ready to send a token, then
                    # get the next token without streaming a chunk
                    if delta_message is None:
                        continue
                    # Log streaming delta if output logging is enabled
                    if self.enable_log_outputs and self.request_logger:
                        delta_content = ""
                        if delta_message.content:
                            delta_content = delta_message.content
                        elif delta_message.tool_calls:
                            delta_content = "".join(
                                tc.function.arguments
                                for tc in delta_message.tool_calls
                                if tc.function and tc.function.arguments)
                        if delta_content:
                            self.request_logger.log_outputs(
                                request_id=request_id,
                                outputs=delta_content,
                                output_token_ids=as_list(output.token_ids),
                                finish_reason=output.finish_reason,
                                is_streaming=True,
                                delta=True,
                            )
                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=None)
                    # if the model is finished generating
                    else:
                        # check to make sure we haven't "forgotten" to stream
                        # any tokens that were generated but previously
                        # matched by partial json parsing
                        # only happens if we are NOT using guided decoding
                        auto_tools_called = False
                        if tool_parser:
                            auto_tools_called = len(
                                tool_parser.prev_tool_call_arr) > 0
                            index = len(tool_parser.prev_tool_call_arr
                                        ) - 1 if auto_tools_called else 0
                        else:
                            index = 0
                        if self._should_check_for_unstreamed_tool_arg_tokens(
                                delta_message, output) and tool_parser:
                            latest_delta_len = 0
                            if ((isinstance(
                                    delta_message.tool_calls[0].function,
                                    DeltaFunctionCall)) and isinstance(
                                        delta_message.tool_calls[0].function.
                                        arguments, str)):
                                latest_delta_len = len(
                                    delta_message.tool_calls[0].function.
                                    arguments)
                            # get the expected call based on partial JSON
                            # parsing which "autocompletes" the JSON
                            expected_call = json.dumps(
                                tool_parser.prev_tool_call_arr[index].get(
                                    "arguments", {}),
                                ensure_ascii=False)
                            # get what we've streamed so far for arguments
                            # for the current tool
                            actual_call = tool_parser.streamed_args_for_tool[
                                index]
                            if (latest_delta_len > 0):
                                actual_call = actual_call[:-latest_delta_len]
                            # check to see if there's anything left to stream
                            remaining_call = expected_call.replace(
                                actual_call, "", 1)
                            # set that as a delta message
                            delta_message = DeltaMessage(tool_calls=[
                                DeltaToolCall(index=index,
                                              function=DeltaFunctionCall(
                                                  arguments=remaining_call).
                                              model_dump(exclude_none=True))
                            ])
                        # Send the finish response for each request.n only once
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=output.finish_reason
                            if not auto_tools_called else "tool_calls",
                            stop_reason=output.stop_reason)
                        finish_reason_sent[i] = True
                    chunk = ChatCompletionStreamResponse(
                        id=request_id,
                        object=chunk_object_type,
                        created=created_time,
                        choices=[choice_data],
                        model=model_name)
                    # handle usage stats if requested & if continuous
                    if include_continuous_usage:
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=num_prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=num_prompt_tokens + completion_tokens,
                        )
                    data = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {data}\n\n"
            # once the final token is handled, if stream_options.include_usage
            # is sent, send the usage
            if include_usage:
                completion_tokens = sum(previous_num_tokens)
                final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
                                        completion_tokens=completion_tokens,
                                        total_tokens=num_prompt_tokens +
                                        completion_tokens)
                if self.enable_prompt_tokens_details and num_cached_tokens:
                    final_usage.prompt_tokens_details = PromptTokenUsageInfo(
                        cached_tokens=num_cached_tokens)
                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage)
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"
            # report to FastAPI middleware aggregate usage across all choices
            num_completion_tokens = sum(previous_num_tokens)
            request_metadata.final_usage_info = UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=num_completion_tokens,
                total_tokens=num_prompt_tokens + num_completion_tokens,
            )
            # Log complete streaming response if output logging is enabled
            if self.enable_log_outputs and self.request_logger:
                # Log the complete response for each choice
                for i in range(num_choices):
                    full_text = (
                        previous_texts[i]
                        if previous_texts and i < len(previous_texts) else
                        f"<streaming_complete: {previous_num_tokens[i]} tokens>"
                    )
                    self.request_logger.log_outputs(
                        request_id=request_id,
                        outputs=full_text,
                        output_token_ids=
                        None,  # Consider also logging all token IDs
                        finish_reason="streaming_complete",
                        is_streaming=True,
                        delta=False,
                    )
        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            logger.exception("Error in chat completion stream generator.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"

    async def chat_completion_full_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        model_name: str,
        conversation: list[ConversationMessage],
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> Union[ErrorResponse, ChatCompletionResponse]:
        """Consume the full engine output and build one non-streaming
        ChatCompletionResponse.

        Drains `result_generator` to its last RequestOutput, then assembles
        per-choice messages (reasoning content, tool calls for named /
        "required" / "auto" tool choice, or plain content), echo text,
        logprobs and usage.  Returns an ErrorResponse on client disconnect
        or parser/validation failure.
        """
        created_time = int(time.time())
        final_res: Optional[RequestOutput] = None
        try:
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))
        assert final_res is not None
        choices: list[ChatCompletionResponseChoice] = []
        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
            token_ids = output.token_ids
            out_logprobs = output.logprobs
            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_chat_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    num_output_top_logprobs=request.top_logprobs,
                    tokenizer=tokenizer,
                    return_as_token_id=request.return_tokens_as_token_ids,
                )
            else:
                logprobs = None
            if self.use_harmony:
                reasoning_content, final_content, is_tool_call = (
                    parse_chat_output(token_ids))
                if not request.include_reasoning:
                    reasoning_content = None
                if is_tool_call:
                    # TODO(woosuk): Implement tool call for gpt-oss.
                    # For now, only Responses API supports tool call for
                    # gpt-oss.
                    raise NotImplementedError(
                        "Tool call in Chat Completion API is not supported "
                        "for gpt-oss yet. Please use Responses API instead.")
                else:
                    # Normal message
                    message = ChatMessage(
                        role=role,
                        reasoning_content=reasoning_content,
                        content=final_content,
                    )
                choice_data = ChatCompletionResponseChoice(
                    index=output.index,
                    message=message,
                    logprobs=logprobs,
                    finish_reason="tool_calls" if is_tool_call else
                    output.finish_reason if output.finish_reason else "stop",
                    stop_reason=output.stop_reason,
                )
                choices.append(choice_data)
                continue
            # Whether the chat template requested reasoning output;
            # defaults to True when no kwargs are provided.
            enable_thinking: bool = request.chat_template_kwargs.get("enable_thinking", True) if request.chat_template_kwargs else True
            if self.reasoning_parser and enable_thinking:
                try:
                    reasoning_parser = self.reasoning_parser(tokenizer)
                except RuntimeError as e:
                    logger.exception("Error in reasoning parser creation.")
                    return self.create_error_response(str(e))
                # If the reasoning parser is enabled,
                # tool calls are extracted exclusively from the content.
                reasoning_content, content = (
                    reasoning_parser.extract_reasoning_content(
                        output.text, request=request))
                if not request.include_reasoning:
                    reasoning_content = None
            else:
                reasoning_content = None
                content = output.text
            auto_tools_called = False
            # if auto tools are not enabled, and a named tool choice using
            # outlines is not being used
            if (not self.enable_auto_tools or not self.tool_parser) and \
                (not isinstance(request.tool_choice,
                                ChatCompletionNamedToolChoiceParam
                                ) and request.tool_choice != "required"):
                message = ChatMessage(role=role,
                                      reasoning_content=reasoning_content,
                                      content=content)
            # if the request uses tools and specified a tool choice
            elif request.tool_choice and type(
                    request.tool_choice) is ChatCompletionNamedToolChoiceParam:
                tool_call_class = MistralToolCall if isinstance(
                    tokenizer, MistralTokenizer) else ToolCall
                message = ChatMessage(
                    role=role,
                    reasoning_content=reasoning_content,
                    content="",
                    tool_calls=[
                        tool_call_class(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=content,
                        ))
                    ],
                )
            elif request.tool_choice and request.tool_choice == "required":
                tool_call_class = MistralToolCall if isinstance(
                    tokenizer, MistralTokenizer) else ToolCall
                # the fields of FunctionDefinition are a superset of the
                # tool call outputs and can be used for parsing
                assert content is not None
                tool_calls = TypeAdapter(
                    list[FunctionDefinition]).validate_json(content)
                message = ChatMessage(
                    role=role,
                    content="",
                    reasoning_content=reasoning_content,
                    tool_calls=[
                        tool_call_class(function=FunctionCall(
                            name=tool_call.name,
                            arguments=json.dumps(tool_call.parameters,
                                                 ensure_ascii=False)))
                        for tool_call in tool_calls
                    ])
            # if the request doesn't use tool choice
            # OR specifies to not use a tool
            elif not request.tool_choice or request.tool_choice == "none":
                message = ChatMessage(role=role,
                                      reasoning_content=reasoning_content,
                                      content=content)
            # handle when there are tools and tool choice is auto
            elif request.tools and (
                    request.tool_choice == "auto"
                    or request.tool_choice is None) and self.enable_auto_tools \
                    and self.tool_parser:
                try:
                    tool_parser = self.tool_parser(tokenizer)
                except RuntimeError as e:
                    logger.exception("Error in tool parser creation.")
                    return self.create_error_response(str(e))
                tool_call_info = tool_parser.extract_tool_calls(
                    content if content is not None else "", request=request)
                # In the OpenAI API the finish_reason is "tools_called"
                # if the tool choice is auto and the model produced a tool
                # call. The same is not true for named function calls
                auto_tools_called = tool_call_info.tools_called
                if tool_call_info.tools_called:
                    message = ChatMessage(role=role,
                                          reasoning_content=reasoning_content,
                                          content=tool_call_info.content,
                                          tool_calls=tool_call_info.tool_calls)
                else:
                    # FOR NOW make it a chat message; we will have to detect
                    # the type to make it later.
                    ret_content = content
                    # try to use content return from tool parser first,
                    # tool parser may do some modify for the content.
                    if (tool_call_info.content
                            and len(tool_call_info.content) > 0):
                        ret_content = tool_call_info.content
                    message = ChatMessage(role=role,
                                          reasoning_content=reasoning_content,
                                          content=ret_content)
            # undetermined case that is still important to handle
            else:
                logger.error(
                    "Error in chat_completion_full_generator - cannot determine"
                    " if tools should be extracted. Returning a standard chat "
                    "completion.")
                message = ChatMessage(role=role,
                                      reasoning_content=reasoning_content,
                                      content=content)
            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=message,
                logprobs=logprobs,
                finish_reason="tool_calls" if auto_tools_called else
                output.finish_reason if output.finish_reason else "stop",
                stop_reason=output.stop_reason)
            choices.append(choice_data)
        if request.echo:
            last_msg_content: Union[str, list[dict[str, str]]] = ""
            if (conversation and "content" in conversation[-1]
                    and conversation[-1].get("role") == role):
                last_msg_content = conversation[-1]["content"] or ""
            if isinstance(last_msg_content, list):
                last_msg_content = "\n".join(msg['text']
                                             for msg in last_msg_content)
            for choice in choices:
                full_message = last_msg_content + (choice.message.content
                                                   or "")
                choice.message.content = full_message
        assert final_res.prompt_token_ids is not None
        num_prompt_tokens = len(final_res.prompt_token_ids)
        if final_res.encoder_prompt_token_ids is not None:
            num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        usage = UsageInfo(prompt_tokens=num_prompt_tokens,
                          completion_tokens=num_generated_tokens,
                          total_tokens=num_prompt_tokens +
                          num_generated_tokens)
        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=final_res.num_cached_tokens)
        request_metadata.final_usage_info = usage
        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
            kv_transfer_params=final_res.kv_transfer_params,
        )
        # Log complete response if output logging is enabled
        if self.enable_log_outputs and self.request_logger:
            for choice in choices:
                output_text = ""
                if choice.message.content:
                    output_text = choice.message.content
                elif choice.message.tool_calls:
                    # For tool calls, log the function name and arguments
                    tool_call_descriptions = []
                    for tool_call in choice.message.tool_calls:
                        if hasattr(tool_call.function, "name") and hasattr(
                                tool_call.function, "arguments"):
                            tool_call_descriptions.append(
                                f"{tool_call.function.name}({tool_call.function.arguments})"
                            )
                    tool_calls_str = ", ".join(tool_call_descriptions)
                    output_text = f"[tool_calls: {tool_calls_str}]"
                if output_text:
                    # Get the corresponding output token IDs
                    output_token_ids = None
                    if choice.index < len(final_res.outputs):
                        output_token_ids = final_res.outputs[
                            choice.index].token_ids
                    self.request_logger.log_outputs(
                        request_id=request_id,
                        outputs=output_text,
                        output_token_ids=output_token_ids,
                        finish_reason=choice.finish_reason,
                        is_streaming=False,
                        delta=False,
                    )
        return response

View File

@@ -0,0 +1,912 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
from functools import partial
from importlib.resources import contents
import json
from collections.abc import Sequence
from typing import Any, Optional, Union
import regex as re
from enum import Enum
from vllm.utils import random_uuid
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionToolsParam,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
class StreamState(str, Enum):
    """States of the XML-to-JSON streaming conversion state machine.

    Inherits ``str`` so members compare equal to their string values and
    serialize transparently.
    """

    # Nothing consumed yet.
    INIT = "INIT"
    # Between elements (presumably after a completed pair) — see parser.
    BETWEEN = "BETWEEN"
    # Currently reading an argument key.
    IN_KEY = "IN_KEY"
    # Key finished; value not started yet.
    WAITING_VALUE = "WAITING_VALUE"
    # Currently reading an argument value.
    IN_VALUE = "IN_VALUE"
def random_tool_call_id() -> str:
    """Return a fresh, unique tool-call id with the chat-completion prefix."""
    return "chatcmpl-tool-" + str(random_uuid())
def get_argument_type(
    func_name: str, arg_key: str, defined_tools: list[ChatCompletionToolsParam]
) -> Optional[str]:
    """Look up the declared type of one argument of one tool.

    Walks ``defined_tools`` for the tool named *func_name*, then drills into
    its JSON-Schema ``parameters.properties[arg_key]`` and delegates the type
    inference to :func:`infer_type_from_json_schema`.

    Args:
        func_name: Name of the function/tool.
        arg_key: Name of the argument.
        defined_tools: List of available tools.

    Returns:
        A type string such as ``'string'``, ``'number'`` or ``'object'``,
        or ``None`` when the tool, the schema or the argument is missing.
    """
    by_name = {t.function.name: t for t in defined_tools}
    tool = by_name.get(func_name)
    if not tool:
        return None
    # `parameters` may be absent or non-dict on loosely validated tools.
    schema = getattr(tool.function, "parameters", None)
    if not isinstance(schema, dict):
        return None
    props = schema.get("properties")
    if not isinstance(props, dict):
        return None
    spec = props.get(arg_key)
    if not isinstance(spec, dict):
        return None
    # Delegate complex JSON-Schema handling (anyOf/enum/allOf/...) here.
    return infer_type_from_json_schema(spec)
def _convert_to_number(value: str) -> Any:
"""Convert string to appropriate number type (int or float).
Args:
value: String value to convert
Returns:
Converted number or original string if conversion fails
"""
try:
if "." in value or "e" in value.lower():
return float(value)
else:
return int(value)
except (ValueError, AttributeError):
return value
def parse_arguments(
    json_value: str, arg_type: Optional[str] = None
) -> tuple[Any, bool]:
    """Parse a raw argument value, trying progressively looser strategies.

    Strategies, in order: plain JSON; JSON after unescaping the payload as an
    embedded JSON string; Python literal syntax; finally treating the raw
    text as a plain string.

    Args:
        json_value: Raw string value to parse.
        arg_type: Expected type hint (``'string'``, ``'number'``, ...); a
            ``'number'`` hint coerces string results via _convert_to_number.

    Returns:
        Tuple of ``(parsed_value, is_valid_json)``.
    """
    # Strategy 1: the value is already valid JSON.
    try:
        value = json.loads(json_value)
        if arg_type == "number" and isinstance(value, str):
            value = _convert_to_number(value)
        return value, True
    except (json.JSONDecodeError, ValueError):
        pass
    # Strategy 2: treat the payload as an escaped JSON string, then parse
    # the unescaped text as JSON.
    try:
        unescaped = json.loads('{"tmp": "' + json_value + '"}')["tmp"]
        value = json.loads(unescaped)
        if arg_type == "number" and isinstance(value, str):
            value = _convert_to_number(value)
        return value, True
    except (json.JSONDecodeError, ValueError, KeyError):
        pass
    # Strategy 3: Python literal syntax (tuples, single-quoted dicts, ...).
    try:
        return ast.literal_eval(json_value), True
    except (ValueError, SyntaxError):
        pass
    # Strategy 4: fall back to the raw text as a plain JSON string.
    try:
        return json.loads(json.dumps(str(json_value))), True
    except (json.JSONDecodeError, ValueError):
        return json_value, False
def infer_type_from_json_schema(schema: dict[str, Any]) -> Optional[str]:
    """Infer the primary type of a parameter from a JSON Schema fragment.

    Handles, in priority order: a direct ``type`` field (including type
    arrays), ``anyOf``/``oneOf`` unions, ``enum`` value inspection,
    ``allOf`` intersections, and the structural hints ``properties``
    (object) and ``items`` (array).

    Args:
        schema: JSON Schema definition.

    Returns:
        Inferred type name ('string', 'number', 'object', ...) or None.
    """
    if not isinstance(schema, dict):
        return None

    # Priority 1: an explicit "type" entry, including union-style arrays.
    if "type" in schema:
        declared = schema["type"]
        if isinstance(declared, str):
            return declared
        if isinstance(declared, list) and declared:
            concrete = [t for t in declared if t != "null"]
            # A pure-null union degenerates to string.
            return concrete[0] if concrete else "string"
        # Malformed "type" entries fall through to the other heuristics.

    # Priority 2: anyOf / oneOf unions.
    if "anyOf" in schema or "oneOf" in schema:
        union = schema.get("anyOf") or schema.get("oneOf")
        candidates: list[str] = []
        if isinstance(union, list):
            for member in union:
                member_type = infer_type_from_json_schema(member)
                if member_type:
                    candidates.append(member_type)
        if candidates:
            if len(set(candidates)) == 1:
                return candidates[0]
            # Mixed unions: prefer string (safest), else the first member.
            return "string" if "string" in candidates else candidates[0]

    # Priority 3: infer from the enum members themselves.
    if "enum" in schema and isinstance(schema["enum"], list):
        members = schema["enum"]
        if not members:
            return "string"
        seen: set[str] = set()
        for member in members:
            # bool must be tested before int (bool is an int subclass).
            if member is None:
                seen.add("null")
            elif isinstance(member, bool):
                seen.add("boolean")
            elif isinstance(member, int):
                seen.add("integer")
            elif isinstance(member, float):
                seen.add("number")
            elif isinstance(member, str):
                seen.add("string")
            elif isinstance(member, list):
                seen.add("array")
            elif isinstance(member, dict):
                seen.add("object")
        # Uniform enums yield their type; mixed or unknown yield string.
        return seen.pop() if len(seen) == 1 else "string"

    # Priority 4: allOf — first member with a concrete non-string type wins.
    if isinstance(schema.get("allOf"), list):
        for member in schema["allOf"]:
            member_type = infer_type_from_json_schema(member)
            if member_type and member_type != "string":
                return member_type
        return "string"

    # Priority 5/6: structural hints.
    if "properties" in schema:
        return "object"
    if "items" in schema:
        return "array"
    return None
@ToolParserManager.register_module("glm47")
class Glm47MoeModelToolParser(ToolParser):
    """Tool-call parser for GLM-4.7 MoE models.

    The model emits tool calls as XML-flavoured blocks of the form
    ``<tool_call>name\\n<arg_key>k</arg_key><arg_value>v</arg_value>...
    </tool_call>`` (see the regexes below). Supports both one-shot
    extraction (`extract_tool_calls`) and incremental streaming conversion
    of that XML into OpenAI-style JSON deltas
    (`extract_tool_calls_streaming`).
    """

    def __init__(self, tokenizer: AnyTokenizer):
        """Initialize parser state and compile the tool-call regexes.

        Raises:
            ValueError: if no tokenizer was supplied.
        """
        super().__init__(tokenizer)
        # Cross-chunk streaming bookkeeping (mutated by the streaming path).
        self.current_tool_name_sent = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id = -1
        self.streamed_args_for_tool: list[str] = []
        self.tool_call_start_token = "<tool_call>"
        self.tool_call_end_token = "</tool_call>"
        # NOTE(review): appears vestigial — current_tool_id tracks the
        # active tool call; confirm before removing.
        self._tool_indices = 0
        # JSON-fragment text already streamed for the current tool call.
        self._last_arguments: str = ""
        # How many chars of the raw XML args have been consumed so far.
        self._streamed_raw_length = 0
        # Alias kept for code that refers to the plural name.
        self.tool_calls_start_token = self.tool_call_start_token
        # Whole <tool_call>...</tool_call> blocks.
        self.func_call_regex = re.compile(r"<tool_call>.*?</tool_call>",
                                          re.DOTALL)
        # Group 1: function name (first line), group 2: raw argument XML.
        self.func_detail_regex = re.compile(
            r"<tool_call>([^\n<]*)\n?(.*)</tool_call>", re.DOTALL)
        # Individual <arg_key>/<arg_value> pairs.
        self.func_arg_regex = re.compile(
            r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>",
            re.DOTALL)
        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction.")
        # May be None when the markers are not single tokens in the vocab.
        self.tool_call_start_token_id = self.vocab.get(
            self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
        # Accumulates delta_text across streaming calls.
        self._buffer = ""
        self._reset_streaming_state()
def _reset_streaming_state(self) -> None:
"""Reset the streaming state machine for a new tool call."""
self._stream_state = StreamState.INIT
self._current_key = ""
self._current_value = ""
self._xml_tag_buffer = ""
self._is_first_param = True
self._value_started = False
self._cached_value_type: Optional[str] = (
None # Cache the value type for consistency
)
self._tool_call_completed = False # Reset tool call completion status
self._sent_empty_object = False # Reset empty object sent status
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
def _is_string_type(
tool_name: str, arg_name: str,
tools: Optional[list[ChatCompletionToolsParam]]) -> bool:
if tools is None:
return False
for tool in tools:
if tool.function.name == tool_name:
if tool.function.parameters is None:
return False
arg_type = tool.function.parameters.get(
"properties", {}).get(arg_name, {}).get("type", None)
return arg_type == "string"
logger.warning("No tool named '%s'.", tool_name)
return False
def _deserialize(value: str) -> Any:
try:
return json.loads(value)
except Exception:
pass
try:
return ast.literal_eval(value)
except Exception:
pass
return value
matched_tool_calls = self.func_call_regex.findall(model_output)
logger.debug("model_output: %s", model_output)
try:
tool_calls = []
for match in matched_tool_calls:
tc_detail = self.func_detail_regex.search(match)
tc_name = tc_detail.group(1)
tc_args = tc_detail.group(2)
pairs = self.func_arg_regex.findall(tc_args)
arg_dct = {}
for key, value in pairs:
arg_key = key.strip()
arg_val = value.strip()
if not _is_string_type(tc_name, arg_key, request.tools):
arg_val = _deserialize(arg_val)
logger.debug("arg_key = %s, arg_val = %s", arg_key,
arg_val)
arg_dct[arg_key] = arg_val
tool_calls.append(
ToolCall(type="function",
function=FunctionCall(
name=tc_name, arguments=json.dumps(arg_dct))))
except Exception:
logger.exception("Failed to extract tool call spec")
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
else:
if len(tool_calls) > 0:
content = model_output[:model_output.
find(self.tool_calls_start_token)]
return ExtractedToolCallInformation(tools_called=True,
tool_calls=tool_calls,
content=content)
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def _extract_match_groups(self, match: re.Match) -> tuple[str, str, str]:
"""Extract function name, arguments and end marker from regex match.
Args:
match: Regex match object
Returns:
(func_name, func_args_raw, is_tool_end)
"""
func_name = match.group(1).strip()
func_args_raw = match.group(2).strip() if match.group(2) else ""
is_tool_end = match.group(3) or ""
return func_name, func_args_raw, is_tool_end
def _send_tool_name_if_needed(
self, func_name: str, has_arg_key: bool, is_tool_end: str
) -> Optional[DeltaToolCall]:
"""Send tool name if needed.
Args:
func_name: Function name
has_arg_key: Whether current text contains <arg_key
is_tool_end: Whether end marker is encountered
Returns:
Tool call item or None
"""
if self.current_tool_name_sent:
return None
# Function name completeness check
is_func_name_complete = has_arg_key or is_tool_end == self.tool_call_end_token
if not is_func_name_complete:
return None
if not func_name:
logger.warning("Empty function name detected, skipping tool call")
return None
# Send tool name
self.current_tool_name_sent = True
self._streamed_raw_length = 0
self._reset_streaming_state()
# Record tool info
self.prev_tool_call_arr[self.current_tool_id] = {
"name": func_name,
"arguments": {},
}
return DeltaToolCall(
id=random_tool_call_id(),
index=self.current_tool_id,
type="function",
function=DeltaFunctionCall(name=func_name, arguments=""),
)
def _parse_argument_pairs(
self, pairs: list[tuple[str, str]], func_name: str, tools: list[ChatCompletionToolsParam]
) -> dict[str, Any]:
"""Parse argument key-value pairs with type coercion.
Args:
pairs: List of (key, value) tuples from regex matching
func_name: Name of the function
tools: List of available tools
Returns:
Dictionary of parsed arguments
"""
arguments = {}
for arg_key, arg_value in pairs:
arg_key = arg_key.strip()
arg_value = arg_value.strip()
arg_type = get_argument_type(func_name, arg_key, tools)
parsed_value, is_good_json = parse_arguments(arg_value, arg_type)
if arg_type == "string":
# Only convert to string if explicitly defined as string type
if isinstance(parsed_value, str):
arguments[arg_key] = parsed_value
elif isinstance(parsed_value, (dict, list)):
# If parsed as dict/list but schema says string, convert to JSON string
arguments[arg_key] = json.dumps(parsed_value, ensure_ascii=False)
else:
arguments[arg_key] = str(parsed_value)
elif arg_type is None:
# If type is not defined, keep the parsed value as-is
arguments[arg_key] = parsed_value if is_good_json else arg_value
else:
# For other types (number, object, array, etc.), use parsed value
arguments[arg_key] = parsed_value if is_good_json else arg_value
return arguments
    def _finalize_tool_call(
        self,
        func_name: str,
        func_args_raw: str,
        tools: list[ChatCompletionToolsParam],
        match_end_pos: int,
        current_text: str,
    ) -> list[DeltaToolCall]:
        """Close out the current tool call when </tool_call> is seen.

        Emits any trailing JSON needed to make the streamed arguments
        well-formed ("{}" for no-arg calls, "}" when the object was left
        open), records the fully-parsed arguments, trims the buffer past
        the match, and resets all per-call streaming state.

        Args:
            func_name: Function name.
            func_args_raw: Raw argument XML string.
            tools: List of available tools (used for type lookup).
            match_end_pos: End offset of the regex match in current_text.
            current_text: Current accumulated text.

        Returns:
            List of closing DeltaToolCall items to emit (possibly empty).
        """
        calls = []
        if self._is_first_param and not self._sent_empty_object:
            # No <arg_key> was ever seen: stream an empty-object argument.
            calls.append(
                DeltaToolCall(
                    index=self.current_tool_id,
                    function=DeltaFunctionCall(name=None, arguments="{}"),
                )
            )
            self._last_arguments += "{}"
            self.streamed_args_for_tool[self.current_tool_id] += "{}"
            self._sent_empty_object = True
        elif not self._last_arguments.endswith("}") and not self._sent_empty_object:
            # Arguments were streamed but the JSON object is still open.
            calls.append(
                DeltaToolCall(
                    index=self.current_tool_id,
                    function=DeltaFunctionCall(name=None, arguments="}"),
                )
            )
            self._last_arguments += "}"
            self.streamed_args_for_tool[self.current_tool_id] += "}"
            self._sent_empty_object = True
        # Re-parse the complete raw XML to record the final argument dict
        # (the streamed fragments above are only the incremental view).
        if func_args_raw:
            try:
                pairs = self.func_arg_regex.findall(func_args_raw)
                if pairs:
                    arguments = self._parse_argument_pairs(pairs, func_name, tools)
                    self.prev_tool_call_arr[self.current_tool_id][
                        "arguments"
                    ] = arguments
            except Exception as e:
                # Best effort: a parse failure only loses the recorded dict.
                logger.debug(f"Failed to parse arguments: {e}", exc_info=True)
        # Drop everything up to and including this tool call from the buffer.
        self._buffer = current_text[match_end_pos:]
        # Advance to the next tool call and reset all per-call state.
        self._tool_call_completed = True
        self.current_tool_id += 1
        self._last_arguments = ""
        self.current_tool_name_sent = False
        self._streamed_raw_length = 0
        self._reset_streaming_state()
        return calls
def _format_value_complete(self, value: str, value_type: str) -> str:
"""Format complete value based on type.
Args:
value: Raw value string
value_type: Expected type ('string', 'number', 'object')
Returns:
Properly formatted JSON value string
"""
if value_type == "string":
# Ensure proper JSON string formatting with quotes
return json.dumps(value, ensure_ascii=False)
elif value_type == "number":
try:
num = _convert_to_number(value.strip() if value else "")
return str(num)
except (ValueError, AttributeError):
# Fallback to string if not a valid number
logger.warning(
f"Failed to parse '{value}' as number, treating as string"
)
return json.dumps(str(value) if value else "", ensure_ascii=False)
else:
# For object/array types, return as-is (should already be valid JSON)
return value
    def _process_xml_to_json_streaming(
        self, raw_increment: str, func_name: str, tools: list[ChatCompletionToolsParam]
    ) -> str:
        """Convert an XML increment to a JSON fragment via the state machine.

        Processes the increment character by character, carrying state
        (``_stream_state``, ``_xml_tag_buffer``, ``_current_key/value``)
        across calls so partial XML tags and values split over chunks are
        handled correctly.

        Args:
            raw_increment: New XML content to process.
            func_name: Name of the function being called.
            tools: List of available tools for type inference.

        Returns:
            JSON string increment to append to the output (may be empty).
        """
        json_output = ""
        for char in raw_increment:
            self._xml_tag_buffer += char
            if self._stream_state in [StreamState.INIT, StreamState.BETWEEN]:
                # Waiting for the next parameter to open.
                if self._xml_tag_buffer.endswith("<arg_key>"):
                    self._stream_state = StreamState.IN_KEY
                    self._current_key = ""
                    self._xml_tag_buffer = ""
                    # Open the object on the first param, else separate.
                    json_output += "{" if self._is_first_param else ", "
                    self._is_first_param = False
            elif self._stream_state == StreamState.IN_KEY:
                # Keys are only emitted once complete (they are short).
                if self._xml_tag_buffer.endswith("</arg_key>"):
                    # len("</arg_key>") == 10
                    self._current_key = self._xml_tag_buffer[:-10].strip()
                    self._xml_tag_buffer = ""
                    self._stream_state = StreamState.WAITING_VALUE
                    json_output += (
                        json.dumps(self._current_key, ensure_ascii=False) + ": "
                    )
            elif self._stream_state == StreamState.WAITING_VALUE:
                if self._xml_tag_buffer.endswith("<arg_value>"):
                    self._stream_state = StreamState.IN_VALUE
                    self._current_value = ""
                    self._xml_tag_buffer = ""
                    self._value_started = False
                    # Cache the type once per value for consistent quoting.
                    # NOTE(review): _current_value is "" at this point, so
                    # _get_value_type's content heuristics cannot fire here —
                    # the cached type is the schema type or "string".
                    self._cached_value_type = self._get_value_type(
                        func_name, self._current_key, tools
                    )
            elif self._stream_state == StreamState.IN_VALUE:
                if self._xml_tag_buffer.endswith("</arg_value>"):
                    # len("</arg_value>") == 12; what precedes it is content.
                    final_value = self._xml_tag_buffer[:-12]
                    self._current_value += final_value
                    value_type = self._cached_value_type or "string"
                    if self._value_started:
                        # Value was partially streamed: flush the tail.
                        if final_value:
                            if value_type == "string":
                                # Escape, dropping json.dumps's outer quotes.
                                json_output += json.dumps(
                                    final_value, ensure_ascii=False
                                )[1:-1]
                            else:
                                json_output += final_value
                        # Close the quote opened when streaming started.
                        if value_type == "string":
                            json_output += '"'
                    else:
                        # Whole value arrived at once (possibly empty).
                        json_output += self._format_value_complete(
                            self._current_value, value_type
                        )
                    self._xml_tag_buffer = ""
                    self._stream_state = StreamState.BETWEEN
                    self._current_value = ""
                    self._value_started = False
                    self._cached_value_type = None
                else:
                    # Hold back text that may be the start of the closing tag;
                    # flush it as value content once it diverges.
                    closing_tag = "</arg_value>"
                    is_potential_closing = len(self._xml_tag_buffer) <= len(
                        closing_tag
                    ) and closing_tag.startswith(self._xml_tag_buffer)
                    if not is_potential_closing:
                        content = self._xml_tag_buffer
                        value_type = self._cached_value_type or "string"
                        if value_type == "string":
                            if not self._value_started:
                                # Open the JSON string on first content.
                                json_output += '"'
                                self._value_started = True
                            if content:
                                json_output += json.dumps(content, ensure_ascii=False)[
                                    1:-1
                                ]
                                self._current_value += content
                            self._xml_tag_buffer = ""
                        elif value_type == "number":
                            if content:
                                if not self._value_started:
                                    self._value_started = True
                                json_output += content
                                self._current_value += content
                            self._xml_tag_buffer = ""
                        else:
                            # object/array: pass raw JSON text through.
                            if content:
                                if not self._value_started:
                                    self._value_started = True
                                json_output += content
                                self._current_value += content
                            self._xml_tag_buffer = ""
        return json_output
def _get_value_type(self, func_name: str, key: str, tools: list[ChatCompletionToolsParam]) -> str:
"""Get parameter type from tool definition, with fallback to auto-detection.
Args:
func_name: Name of the function
key: Parameter name
tools: List of available tools
Returns:
Type string: 'string', 'number', 'object', 'array', or 'boolean'
"""
arg_type = get_argument_type(func_name, key, tools)
if arg_type:
return arg_type
# Improved auto-detection type from value (best effort)
value_content = self._current_value.strip() if self._current_value else ""
if not value_content:
return "string"
# Try to parse as valid JSON first
try:
parsed = json.loads(value_content)
if isinstance(parsed, dict):
return "object"
elif isinstance(parsed, list):
return "array"
elif isinstance(parsed, bool):
return "boolean"
elif isinstance(parsed, (int, float)):
return "number"
# For string values, check if they look like numbers
elif isinstance(parsed, str):
if parsed.isdigit() or (
parsed.startswith("-") and parsed[1:].isdigit()
):
return "number"
return "string"
except json.JSONDecodeError:
# Not valid JSON, try heuristic detection
first_char = value_content[0] if value_content else ""
if first_char.isdigit() or first_char in ["-", "."]:
return "number"
elif first_char in ["{", "["]:
return "object"
elif first_char in ['"', "'"]:
return "string"
# Default to string (safest fallback)
return "string"
def _process_arguments_streaming(
self, func_name: str, func_args_raw: str, tools: list[ChatCompletionToolsParam]
) -> Optional[DeltaToolCall]:
"""Process streaming arguments.
Args:
func_name: Function name
func_args_raw: Raw argument string
tools: List of available tools
Returns:
Tool call item with parameter updates or None
"""
current_raw_length = len(func_args_raw)
if current_raw_length <= self._streamed_raw_length:
return None
# Get new raw XML content
raw_increment = func_args_raw[self._streamed_raw_length :]
# Convert XML to JSON using state machine
json_increment = self._process_xml_to_json_streaming(
raw_increment, func_name, tools
)
# CRITICAL: Update streamed length BEFORE early return
# Even if json_increment is empty, the input has been consumed by the state machine
self._streamed_raw_length = current_raw_length
if not json_increment:
return None
# Update state
self._last_arguments += json_increment
self.streamed_args_for_tool[self.current_tool_id] += json_increment
return DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(name=None, arguments=json_increment),
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
self._buffer += delta_text
current_text = self._buffer
# Check if we have a tool call
has_tool_call = self.tool_call_start_token in current_text
if not has_tool_call:
# Check if buffer could be the start of a tool call
# Keep buffer if it could be a partial match of bot_token
is_potential_start = any(
self.tool_call_start_token.startswith(current_text[-i:])
for i in range(1, min(len(current_text), len(self.tool_call_start_token)) + 1)
)
if not is_potential_start:
# Not a potential tool call, return as normal text
# Must return the entire buffer (current_text), not just new_text,
# because buffer may contain previously accumulated characters like '<'
# that turned out not to be part of a tool call
output_text = current_text
self._buffer = ""
if self.tool_call_end_token in output_text:
output_text = output_text.replace(self.tool_call_end_token, "")
return DeltaMessage(content=output_text)
else:
# Could be start of tool call, keep buffering
return None
# Extract any text before the first bot_token and return it as normal_text
output_text = ""
first_bot_token_idx = current_text.find(self.tool_call_start_token)
if first_bot_token_idx > 0:
output_text= current_text[:first_bot_token_idx]
current_text = current_text[first_bot_token_idx:]
# Update buffer to only include from the bot token onwards
self._buffer = current_text
if not hasattr(self, "_tool_indices"):
self._tool_indices += 1
calls: list[DeltaToolCall] = []
try:
# Try to match a partial or complete tool call
# Use a single flexible regex pattern that handles all cases
partial_match = re.search(
r"<tool_call>(.*?)(?:(<arg_key.*?))?(?:(</tool_call>)|$)",
current_text,
re.DOTALL,
)
if not partial_match:
return None
# return DeltaMessage(content=output_text, tool_calls=[])
# Extract match groups using helper method
func_name, func_args_raw, is_tool_end = self._extract_match_groups(
match=partial_match
)
# Initialize tool call state if needed (keeping existing logic)
if self.current_tool_id == -1:
self.current_tool_id = 0
self.prev_tool_call_arr = []
self.streamed_args_for_tool = [""]
self._streamed_raw_length = 0
self.current_tool_name_sent = False # Reset for new tool call
self._reset_streaming_state()
# Check if this is a continuation of an existing tool call or a new one
elif not self.current_tool_name_sent:
# Only increment tool_id if we're truly starting a NEW tool call
# Don't increment if this is just the first time we're processing
# a tool call that was received in the buffer
# The key insight: only increment when we've COMPLETED a previous tool call
# and now see another bot_token in new_text
pass # Remove the problematic auto-increment logic
# Ensure tracking arrays are large enough (keeping existing logic)
while len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
while len(self.streamed_args_for_tool) <= self.current_tool_id:
self.streamed_args_for_tool.append("")
# Determine if function name is complete by checking for <arg_key> in the full text
# This is important for streaming scenarios where args come in later chunks
has_arg_key = "<arg_key" in current_text
# Send tool name if needed
tool_name_item = self._send_tool_name_if_needed(
func_name, has_arg_key, is_tool_end
)
if tool_name_item:
calls.append(tool_name_item)
# Process streaming arguments if tool name has been sent
if self.current_tool_name_sent and request.tools:
arg_item = self._process_arguments_streaming(
func_name, func_args_raw, request.tools
)
if arg_item:
calls.append(arg_item)
# Finalize tool call if end token is encountered
if is_tool_end == self.tool_call_end_token and not self._tool_call_completed:
finalize_calls = self._finalize_tool_call(
func_name,
func_args_raw,
request.tools,
partial_match.end(),
current_text,
)
calls.extend(finalize_calls)
return DeltaMessage(content=output_text, tool_calls=calls)
except Exception as e:
logger.error(f"Error in parse_streaming_increment: {e}", exc_info=True)
return DeltaMessage(content=output_text)
# Only return if we have meaningful content or tool calls to avoid empty chunks
if output_text.strip() or calls:
return DeltaMessage(content=output_text, tool_calls=calls)
else:
return None