[dev] Add GLM-4.7 tool parser (#151)
Signed-off-by: zhangzhenyi <zhangzhenyi@baidu.com> Co-authored-by: Li Wei <liwei.109@outlook.com>
This commit is contained in:
948
vllm_kunlun/entrypoints/openai/serving_chat.py
Normal file
948
vllm_kunlun/entrypoints/openai/serving_chat.py
Normal file
@@ -0,0 +1,948 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
from typing import Callable, Final, Optional, Union
|
||||
|
||||
import jinja2
|
||||
import partial_json_parser
|
||||
import regex as re
|
||||
from fastapi import Request
|
||||
from openai_harmony import Message as OpenAIMessage
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
|
||||
ConversationMessage,
|
||||
random_tool_call_id)
|
||||
from vllm.entrypoints.harmony_utils import (
|
||||
get_developer_message, get_stop_tokens_for_assistant_actions,
|
||||
get_streamable_parser_for_assistant, get_system_message, parse_chat_input,
|
||||
parse_chat_output, render_for_completion)
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionLogProb, ChatCompletionLogProbs,
|
||||
ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
|
||||
ChatCompletionRequest, ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall, ErrorResponse, FunctionCall, FunctionDefinition,
|
||||
PromptTokenUsageInfo, RequestResponseMetadata, ToolCall, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
|
||||
clamp_prompt_logprobs)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
|
||||
MistralToolCall)
|
||||
from vllm.entrypoints.utils import get_max_tokens
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls,
|
||||
truncate_tool_call_ids,
|
||||
validate_request_params)
|
||||
from vllm.utils import as_list
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
    async def chat_completion_stream_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        model_name: str,
        conversation: list[ConversationMessage],
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
        enable_force_include_usage: bool,
    ) -> AsyncGenerator[str, None]:
        """Stream a chat completion as OpenAI-style SSE chunks.

        Consumes ``result_generator`` and yields ``data: {json}\\n\\n`` frames
        (``chat.completion.chunk`` objects), ending with ``data: [DONE]\\n\\n``.
        Depending on the request, each delta may carry plain content,
        ``reasoning_content`` (when a reasoning parser is configured), or
        incremental ``tool_calls`` (named tool choice, "required", or "auto"
        via the configured tool parser). Errors are reported as a final SSE
        error frame rather than raised, so the HTTP stream stays well-formed.

        Args:
            request: The validated chat completion request.
            result_generator: Async stream of engine outputs for this request.
            request_id: Identifier echoed in every chunk's ``id`` field.
            model_name: Model name echoed in every chunk.
            conversation: Parsed input messages (used for ``echo``).
            tokenizer: Tokenizer used to build reasoning/tool parsers and
                logprobs.
            request_metadata: Mutated in place with the final usage info so
                middleware can report aggregate token counts.
            enable_force_include_usage: Force usage reporting even if the
                client did not ask for it via ``stream_options``.

        Yields:
            SSE-framed strings, one per chunk, then the ``[DONE]`` sentinel.
        """
        created_time = int(time.time())
        chunk_object_type: Final = "chat.completion.chunk"
        first_iteration = True

        # Send response for each token for each request.n (index)
        num_choices = 1 if request.n is None else request.n
        # Per-choice bookkeeping, indexed by output.index.
        previous_num_tokens = [0] * num_choices
        finish_reason_sent = [False] * num_choices
        num_prompt_tokens = 0
        num_cached_tokens = None
        if self.use_harmony:
            # One streamable harmony parser per choice (gpt-oss style output).
            harmony_parsers = [
                get_streamable_parser_for_assistant()
                for _ in range(num_choices)
            ]

        if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_choice_function_name = request.tool_choice.function.name
        else:
            tool_choice_function_name = None

        # Determine whether tools are in use with "auto" tool choice
        tool_choice_auto = (
            not tool_choice_function_name
            and self._should_stream_with_auto_tool_parsing(request))

        all_previous_token_ids: Optional[list[list[int]]]
        function_name_returned = [False] * num_choices

        # Always track previous_texts for comprehensive output logging
        previous_texts = [""] * num_choices

        # Only one of these will be used, thus previous_texts and
        # all_previous_token_ids will not be used twice in the same iteration.
        if tool_choice_auto or self.reasoning_parser:
            # These are only required in "auto" tool choice case
            # NOTE: all entries initially alias one shared empty list; this is
            # safe because entries are only *reassigned* per index below
            # (all_previous_token_ids[i] = ...), never mutated in place.
            all_previous_token_ids = [[]] * num_choices
            # For reasoning parser and tool call all enabled
            added_content_delta_arr = [False] * num_choices
            reasoning_end_arr = [False] * num_choices
        elif request.tool_choice == "required":
            all_previous_token_ids = None
        else:
            all_previous_token_ids = None

        # Honor chat_template_kwargs["enable_thinking"]; defaults to True when
        # absent, i.e. reasoning output is streamed unless explicitly disabled.
        enable_thinking: bool = request.chat_template_kwargs.get("enable_thinking", True) if request.chat_template_kwargs else True

        try:
            if self.reasoning_parser:
                reasoning_parser = self.reasoning_parser(tokenizer)
        except RuntimeError as e:
            logger.exception("Error in reasoning parser creation.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"
            return
        # Prepare the tool parser if it's needed
        try:
            if tool_choice_auto and self.tool_parser:
                # NOTE(review): a single tool-parser instance is shared across
                # all n choices here; tool parsers keep streaming state
                # (prev_tool_call_arr, streamed_args_for_tool), so per-choice
                # state may interleave when request.n > 1 — confirm intended.
                tool_parsers: list[Optional[ToolParser]] = [
                    self.tool_parser(tokenizer)
                ] * num_choices
            else:
                tool_parsers = [None] * num_choices
        except Exception as e:
            logger.exception("Error in tool parser creation.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"
            return

        stream_options = request.stream_options
        if stream_options:
            include_usage = stream_options.include_usage \
                or enable_force_include_usage
            include_continuous_usage = include_usage and \
                stream_options.continuous_usage_stats
        else:
            include_usage, include_continuous_usage = False, False

        try:
            async for res in result_generator:
                if res.prompt_token_ids is not None:
                    num_prompt_tokens = len(res.prompt_token_ids)
                    if res.encoder_prompt_token_ids is not None:
                        num_prompt_tokens += len(res.encoder_prompt_token_ids)

                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
                    num_cached_tokens = res.num_cached_tokens
                    # Send first response for each request.n (index) with
                    # the role
                    role = self.get_chat_request_role(request)

                    # NOTE num_choices defaults to 1 so this usually executes
                    # once per request
                    for i in range(num_choices):
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(
                                role=role,
                                content="",
                            ),
                            logprobs=None,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)

                        # if continuous usage stats are requested, add it
                        if include_continuous_usage:
                            chunk.usage = UsageInfo(
                                prompt_tokens=num_prompt_tokens,
                                completion_tokens=0,
                                total_tokens=num_prompt_tokens)

                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

                    # Send response to echo the input portion of the
                    # last message
                    if request.echo:
                        last_msg_content: Union[str, list[dict[str, str]]] = ""
                        if conversation and "content" in conversation[
                                -1] and conversation[-1].get("role") == role:
                            last_msg_content = conversation[-1]["content"] or ""

                        if last_msg_content:
                            for i in range(num_choices):
                                choice_data = (
                                    ChatCompletionResponseStreamChoice(
                                        index=i,
                                        delta=DeltaMessage(
                                            content=last_msg_content),
                                        logprobs=None,
                                        finish_reason=None))
                                chunk = ChatCompletionStreamResponse(
                                    id=request_id,
                                    object=chunk_object_type,
                                    created=created_time,
                                    choices=[choice_data],
                                    model=model_name)
                                if include_continuous_usage:
                                    chunk.usage = UsageInfo(
                                        prompt_tokens=num_prompt_tokens,
                                        completion_tokens=0,
                                        total_tokens=num_prompt_tokens)

                                data = chunk.model_dump_json(
                                    exclude_unset=True)
                                yield f"data: {data}\n\n"
                    first_iteration = False

                for output in res.outputs:
                    i = output.index
                    tool_parser = tool_parsers[i]

                    if finish_reason_sent[i]:
                        continue

                    if request.logprobs and request.top_logprobs is not None:
                        assert output.logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_chat_logprobs(
                            token_ids=output.token_ids,
                            top_logprobs=output.logprobs,
                            tokenizer=tokenizer,
                            num_output_top_logprobs=request.top_logprobs,
                            return_as_token_id=request.
                            return_tokens_as_token_ids,
                        )
                    else:
                        logprobs = None

                    if self.use_harmony:
                        harmony_parser = harmony_parsers[i]
                        for token_id in output.token_ids:
                            harmony_parser.process(token_id)
                        # FIXME(woosuk): Support function calling
                        is_final = harmony_parser.current_channel == "final"
                        if not (request.include_reasoning or is_final):
                            # Skip the reasoning content.
                            continue
                        delta_text = harmony_parser.last_content_delta or ""
                    else:
                        delta_text = output.text

                    if not delta_text and not output.token_ids and \
                        not previous_num_tokens[i]:
                        # Chunked prefill case, don't return empty chunks
                        continue

                    delta_message: Optional[DeltaMessage]

                    # just update previous_texts and previous_token_ids
                    if ((tool_choice_auto or self.reasoning_parser)
                            and not self.use_harmony):
                        assert previous_texts is not None
                        assert all_previous_token_ids is not None
                        previous_text = previous_texts[i]
                        previous_token_ids = all_previous_token_ids[i]
                        current_text = previous_text + delta_text

                        # avoid the None + list error.
                        if previous_token_ids:
                            current_token_ids = previous_token_ids + as_list(
                                output.token_ids)
                        else:
                            current_token_ids = as_list(output.token_ids)

                    if self.use_harmony:
                        # Harmony routes content by channel: "final" is user
                        # visible content, everything else is reasoning.
                        if is_final:
                            delta_message = DeltaMessage(content=delta_text)
                        else:
                            delta_message = DeltaMessage(
                                reasoning_content=delta_text)
                    # handle streaming deltas for tools with named tool_choice
                    elif tool_choice_function_name:
                        if (self.reasoning_parser and not reasoning_end_arr[i]
                                and not reasoning_parser.is_reasoning_end(
                                    previous_token_ids)):
                            assert reasoning_parser is not None
                            delta_message = (
                                reasoning_parser.
                                extract_reasoning_content_streaming(
                                    previous_text,
                                    current_text,
                                    delta_text,
                                    previous_token_ids,
                                    current_token_ids,
                                    output.token_ids,
                                ))
                            # When encountering think end id in delta_token_ids
                            # or think end id in prompt_token_ids
                            # i.e {"enable_thinking": False},
                            # set reasoning status to end.
                            # Only keep 'content', remove 'reasoning_content'.
                            if reasoning_parser.is_reasoning_end(
                                    as_list(output.token_ids)) or (
                                        res.prompt_token_ids
                                        and reasoning_parser.is_reasoning_end(
                                            res.prompt_token_ids)):
                                reasoning_end_arr[i] = True
                                if delta_message and delta_message.content:
                                    # This need to be added to next `delta_text`
                                    current_text = delta_message.content
                                    delta_message.content = None
                                else:
                                    current_text = ""
                        else:
                            # Just to add remaining `content`
                            if self.reasoning_parser:
                                delta_text = previous_text + delta_text
                                current_text = ""

                            if function_name_returned[i]:
                                # Name already streamed; only append arguments.
                                delta_tool_call = DeltaToolCall(
                                    function=DeltaFunctionCall(
                                        arguments=delta_text),
                                    index=i)
                            else:
                                # First tool-call chunk carries id/type/name.
                                delta_tool_call = DeltaToolCall(
                                    id=random_tool_call_id(),
                                    type="function",
                                    function=DeltaFunctionCall(
                                        name=tool_choice_function_name,
                                        arguments=delta_text),
                                    index=i)
                                function_name_returned[i] = True

                            delta_message = DeltaMessage(tool_calls=[
                                delta_tool_call,
                            ])

                    elif request.tool_choice == "required":
                        assert previous_texts is not None
                        previous_text = previous_texts[i]
                        current_text = previous_text + delta_text
                        fn_name_returned = function_name_returned[i]

                        if self.reasoning_parser:
                            _, content = \
                                reasoning_parser.extract_reasoning_content(
                                    current_text,
                                    request
                                )
                        else:
                            content = current_text
                        delta_message, function_name_returned[i] = (
                            self.extract_tool_call_required_streaming(
                                previous_text=previous_text,
                                current_text=content,
                                delta_text=delta_text,
                                function_name_returned=fn_name_returned))

                        # update the previous values for the next iteration
                        previous_texts[i] = current_text

                    # handle streaming deltas for tools with "auto" tool choice
                    # and reasoning parser
                    elif tool_choice_auto and self.reasoning_parser:
                        assert tool_parser is not None
                        assert reasoning_parser is not None
                        assert added_content_delta_arr is not None
                        assert reasoning_end_arr is not None
                        output_token_ids = as_list(output.token_ids)
                        if not reasoning_end_arr[i]:
                            delta_message = (
                                reasoning_parser.
                                extract_reasoning_content_streaming(
                                    previous_text,
                                    current_text,
                                    delta_text,
                                    previous_token_ids,
                                    current_token_ids,
                                    output_token_ids,
                                ))
                            # When encountering think end id in prompt_token_ids
                            # i.e {"enable_thinking": False},
                            # set reasoning status to end.
                            # Remove the text and token ids related
                            # to 'reasoning_content'.
                            if not enable_thinking:
                                reasoning_end_arr[i] = True
                                current_token_ids = output_token_ids
                                if delta_message and delta_message.reasoning_content:
                                    current_text = delta_message.reasoning_content
                                    delta_message.content = None
                                    delta_message.reasoning_content = None
                                else:
                                    current_text = delta_message.content
                            # When encountering think end id in delta_token_ids,
                            # set reasoning status to end.
                            # Remove the text and token ids related
                            # to 'reasoning_content'.
                            if reasoning_parser.is_reasoning_end(
                                    output_token_ids):
                                reasoning_end_arr[i] = True
                                current_token_ids = \
                                    reasoning_parser.extract_content_ids(
                                        output_token_ids)
                                if delta_message and delta_message.content:
                                    current_text = delta_message.content
                                    delta_message.content = None
                                else:
                                    current_text = ""

                        # handle tool calls only after reasoning is done,
                        else:
                            delta_token_ids = output_token_ids
                            # First time to tool call,
                            # add the remaining text and token ids
                            # to delta from previous
                            if not added_content_delta_arr[i]:
                                added_content_delta_arr[i] = True
                                previous_text = ""
                                previous_token_ids = []
                                delta_text = current_text
                                delta_token_ids = current_token_ids

                            delta_message = (
                                tool_parser.extract_tool_calls_streaming(
                                    previous_text=previous_text,
                                    current_text=current_text,
                                    delta_text=delta_text,
                                    previous_token_ids=previous_token_ids,
                                    current_token_ids=current_token_ids,
                                    delta_token_ids=delta_token_ids,
                                    request=request))
                    # when only tool calls
                    elif tool_choice_auto:
                        assert tool_parser is not None
                        delta_message = (
                            tool_parser.extract_tool_calls_streaming(
                                previous_text=previous_text,
                                current_text=current_text,
                                delta_text=delta_text,
                                previous_token_ids=previous_token_ids,
                                current_token_ids=current_token_ids,
                                delta_token_ids=output.token_ids,
                                request=request))

                    # when only reasoning
                    elif self.reasoning_parser and enable_thinking:
                        delta_message = (reasoning_parser.
                                         extract_reasoning_content_streaming(
                                             previous_text,
                                             current_text,
                                             delta_text,
                                             previous_token_ids,
                                             current_token_ids,
                                             output.token_ids,
                                         ))
                    # handle streaming just a content delta
                    else:
                        delta_message = DeltaMessage(content=delta_text)

                    # update the previous values for the next iteration
                    if tool_choice_auto or self.reasoning_parser:
                        assert previous_texts is not None
                        assert all_previous_token_ids is not None
                        previous_texts[i] = current_text
                        all_previous_token_ids[i] = current_token_ids
                    else:
                        # Update for comprehensive logging even in simple case
                        assert previous_texts is not None
                        previous_texts[i] += delta_text

                    # set the previous values for the next iteration
                    previous_num_tokens[i] += len(output.token_ids)

                    # if the message delta is None (e.g. because it was a
                    # "control token" for tool calls or the parser otherwise
                    # wasn't ready to send a token, then
                    # get the next token without streaming a chunk
                    if delta_message is None:
                        continue

                    # Log streaming delta if output logging is enabled
                    if self.enable_log_outputs and self.request_logger:
                        delta_content = ""
                        if delta_message.content:
                            delta_content = delta_message.content
                        elif delta_message.tool_calls:
                            delta_content = "".join(
                                tc.function.arguments
                                for tc in delta_message.tool_calls
                                if tc.function and tc.function.arguments)

                        if delta_content:
                            self.request_logger.log_outputs(
                                request_id=request_id,
                                outputs=delta_content,
                                output_token_ids=as_list(output.token_ids),
                                finish_reason=output.finish_reason,
                                is_streaming=True,
                                delta=True,
                            )

                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=None)

                    # if the model is finished generating
                    else:
                        # check to make sure we haven't "forgotten" to stream
                        # any tokens that were generated but previously
                        # matched by partial json parsing
                        # only happens if we are NOT using guided decoding
                        auto_tools_called = False
                        if tool_parser:
                            auto_tools_called = len(
                                tool_parser.prev_tool_call_arr) > 0
                            index = len(tool_parser.prev_tool_call_arr
                                        ) - 1 if auto_tools_called else 0
                        else:
                            index = 0

                        if self._should_check_for_unstreamed_tool_arg_tokens(
                                delta_message, output) and tool_parser:
                            latest_delta_len = 0
                            if ((isinstance(
                                    delta_message.tool_calls[0].function,
                                    DeltaFunctionCall)) and isinstance(
                                        delta_message.tool_calls[0].function.
                                        arguments, str)):
                                latest_delta_len = len(
                                    delta_message.tool_calls[0].function.
                                    arguments)

                            # get the expected call based on partial JSON
                            # parsing which "autocompletes" the JSON
                            expected_call = json.dumps(
                                tool_parser.prev_tool_call_arr[index].get(
                                    "arguments", {}),
                                ensure_ascii=False)

                            # get what we've streamed so far for arguments
                            # for the current tool
                            actual_call = tool_parser.streamed_args_for_tool[
                                index]
                            if (latest_delta_len > 0):
                                actual_call = actual_call[:-latest_delta_len]

                            # check to see if there's anything left to stream
                            remaining_call = expected_call.replace(
                                actual_call, "", 1)
                            # set that as a delta message
                            delta_message = DeltaMessage(tool_calls=[
                                DeltaToolCall(index=index,
                                              function=DeltaFunctionCall(
                                                  arguments=remaining_call).
                                              model_dump(exclude_none=True))
                            ])

                        # Send the finish response for each request.n only once
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=output.finish_reason
                            if not auto_tools_called else "tool_calls",
                            stop_reason=output.stop_reason)

                        finish_reason_sent[i] = True

                    chunk = ChatCompletionStreamResponse(
                        id=request_id,
                        object=chunk_object_type,
                        created=created_time,
                        choices=[choice_data],
                        model=model_name)

                    # handle usage stats if requested & if continuous
                    if include_continuous_usage:
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=num_prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=num_prompt_tokens + completion_tokens,
                        )

                    data = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {data}\n\n"

            # once the final token is handled, if stream_options.include_usage
            # is sent, send the usage
            if include_usage:
                completion_tokens = sum(previous_num_tokens)
                final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
                                        completion_tokens=completion_tokens,
                                        total_tokens=num_prompt_tokens +
                                        completion_tokens)
                if self.enable_prompt_tokens_details and num_cached_tokens:
                    final_usage.prompt_tokens_details = PromptTokenUsageInfo(
                        cached_tokens=num_cached_tokens)

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage)
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            num_completion_tokens = sum(previous_num_tokens)
            request_metadata.final_usage_info = UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=num_completion_tokens,
                total_tokens=num_prompt_tokens + num_completion_tokens,
            )

            # Log complete streaming response if output logging is enabled
            if self.enable_log_outputs and self.request_logger:
                # Log the complete response for each choice
                for i in range(num_choices):
                    full_text = (
                        previous_texts[i]
                        if previous_texts and i < len(previous_texts) else
                        f"<streaming_complete: {previous_num_tokens[i]} tokens>"
                    )
                    self.request_logger.log_outputs(
                        request_id=request_id,
                        outputs=full_text,
                        output_token_ids=
                        None,  # Consider also logging all token IDs
                        finish_reason="streaming_complete",
                        is_streaming=True,
                        delta=False,
                    )

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            logger.exception("Error in chat completion stream generator.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"
|
||||
|
||||
    async def chat_completion_full_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        model_name: str,
        conversation: list[ConversationMessage],
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> Union[ErrorResponse, ChatCompletionResponse]:
        """Build the complete (non-streaming) chat completion response.

        Drains ``result_generator`` to the final engine output, then for each
        choice: optionally splits reasoning content from the visible content
        (reasoning parser + ``enable_thinking``), and materializes tool calls
        according to ``request.tool_choice`` (named, "required", "auto" via
        the configured tool parser, or none). Also handles ``echo``, usage
        accounting, and optional output logging.

        Args:
            request: The validated chat completion request.
            result_generator: Async stream of engine outputs; only the last
                item is used.
            request_id: Identifier placed in the response ``id`` field.
            model_name: Model name placed in the response.
            conversation: Parsed input messages (used for ``echo``).
            tokenizer: Tokenizer used to build reasoning/tool parsers and
                logprobs.
            request_metadata: Mutated in place with the final usage info.

        Returns:
            A ``ChatCompletionResponse`` on success, or an ``ErrorResponse``
            (client disconnect, validation error, parser creation failure).

        Raises:
            NotImplementedError: If a harmony (gpt-oss) output contains a
                tool call, which this API does not support yet.
        """

        created_time = int(time.time())
        final_res: Optional[RequestOutput] = None

        try:
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        assert final_res is not None

        choices: list[ChatCompletionResponseChoice] = []

        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
            token_ids = output.token_ids
            out_logprobs = output.logprobs

            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_chat_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    num_output_top_logprobs=request.top_logprobs,
                    tokenizer=tokenizer,
                    return_as_token_id=request.return_tokens_as_token_ids,
                )
            else:
                logprobs = None

            if self.use_harmony:
                # Harmony (gpt-oss) output path: parse channels from tokens.
                reasoning_content, final_content, is_tool_call = (
                    parse_chat_output(token_ids))
                if not request.include_reasoning:
                    reasoning_content = None

                if is_tool_call:
                    # TODO(woosuk): Implement tool call for gpt-oss.
                    # For now, only Responses API supports tool call for
                    # gpt-oss.
                    raise NotImplementedError(
                        "Tool call in Chat Completion API is not supported "
                        "for gpt-oss yet. Please use Responses API instead.")
                else:
                    # Normal message
                    message = ChatMessage(
                        role=role,
                        reasoning_content=reasoning_content,
                        content=final_content,
                    )

                choice_data = ChatCompletionResponseChoice(
                    index=output.index,
                    message=message,
                    logprobs=logprobs,
                    finish_reason="tool_calls" if is_tool_call else
                    output.finish_reason if output.finish_reason else "stop",
                    stop_reason=output.stop_reason,
                )
                choices.append(choice_data)
                # Harmony path is complete; skip the non-harmony handling.
                continue

            # Honor chat_template_kwargs["enable_thinking"]; defaults to True.
            enable_thinking: bool = request.chat_template_kwargs.get("enable_thinking", True) if request.chat_template_kwargs else True
            if self.reasoning_parser and enable_thinking:
                try:
                    reasoning_parser = self.reasoning_parser(tokenizer)
                except RuntimeError as e:
                    logger.exception("Error in reasoning parser creation.")
                    return self.create_error_response(str(e))
                # If the reasoning parser is enabled,
                # tool calls are extracted exclusively from the content.
                reasoning_content, content = (
                    reasoning_parser.extract_reasoning_content(
                        output.text, request=request))
                if not request.include_reasoning:
                    reasoning_content = None
            else:
                reasoning_content = None
                content = output.text

            auto_tools_called = False
            # if auto tools are not enabled, and a named tool choice using
            # outlines is not being used
            if (not self.enable_auto_tools or not self.tool_parser) and \
                (not isinstance(request.tool_choice,
                                ChatCompletionNamedToolChoiceParam
                                ) and request.tool_choice != "required"):
                message = ChatMessage(role=role,
                                      reasoning_content=reasoning_content,
                                      content=content)

            # if the request uses tools and specified a tool choice
            elif request.tool_choice and type(
                    request.tool_choice) is ChatCompletionNamedToolChoiceParam:

                tool_call_class = MistralToolCall if isinstance(
                    tokenizer, MistralTokenizer) else ToolCall
                # Named tool choice: the entire content is the arguments blob
                # for the single requested function.
                message = ChatMessage(
                    role=role,
                    reasoning_content=reasoning_content,
                    content="",
                    tool_calls=[
                        tool_call_class(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=content,
                        ))
                    ],
                )

            elif request.tool_choice and request.tool_choice == "required":
                tool_call_class = MistralToolCall if isinstance(
                    tokenizer, MistralTokenizer) else ToolCall

                # the fields of FunctionDefinition are a superset of the
                # tool call outputs and can be used for parsing
                assert content is not None
                tool_calls = TypeAdapter(
                    list[FunctionDefinition]).validate_json(content)
                message = ChatMessage(
                    role=role,
                    content="",
                    reasoning_content=reasoning_content,
                    tool_calls=[
                        tool_call_class(function=FunctionCall(
                            name=tool_call.name,
                            arguments=json.dumps(tool_call.parameters,
                                                 ensure_ascii=False)))
                        for tool_call in tool_calls
                    ])

            # if the request doesn't use tool choice
            # OR specifies to not use a tool
            elif not request.tool_choice or request.tool_choice == "none":

                message = ChatMessage(role=role,
                                      reasoning_content=reasoning_content,
                                      content=content)

            # handle when there are tools and tool choice is auto
            elif request.tools and (
                    request.tool_choice == "auto"
                    or request.tool_choice is None) and self.enable_auto_tools \
                    and self.tool_parser:

                try:
                    tool_parser = self.tool_parser(tokenizer)
                except RuntimeError as e:
                    logger.exception("Error in tool parser creation.")
                    return self.create_error_response(str(e))

                tool_call_info = tool_parser.extract_tool_calls(
                    content if content is not None else "", request=request)
                # In the OpenAI API the finish_reason is "tools_called"
                # if the tool choice is auto and the model produced a tool
                # call. The same is not true for named function calls
                auto_tools_called = tool_call_info.tools_called
                if tool_call_info.tools_called:
                    message = ChatMessage(role=role,
                                          reasoning_content=reasoning_content,
                                          content=tool_call_info.content,
                                          tool_calls=tool_call_info.tool_calls)

                else:
                    # FOR NOW make it a chat message; we will have to detect
                    # the type to make it later.
                    ret_content = content

                    # try to use content return from tool parser first,
                    # tool parser may do some modify for the content.
                    if (tool_call_info.content
                            and len(tool_call_info.content) > 0):
                        ret_content = tool_call_info.content

                    message = ChatMessage(role=role,
                                          reasoning_content=reasoning_content,
                                          content=ret_content)

            # undetermined case that is still important to handle
            else:
                logger.error(
                    "Error in chat_completion_full_generator - cannot determine"
                    " if tools should be extracted. Returning a standard chat "
                    "completion.")
                message = ChatMessage(role=role,
                                      reasoning_content=reasoning_content,
                                      content=content)

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=message,
                logprobs=logprobs,
                finish_reason="tool_calls" if auto_tools_called else
                output.finish_reason if output.finish_reason else "stop",
                stop_reason=output.stop_reason)

            choices.append(choice_data)

        if request.echo:
            # Prepend the last input message's content to every choice.
            last_msg_content: Union[str, list[dict[str, str]]] = ""
            if (conversation and "content" in conversation[-1]
                    and conversation[-1].get("role") == role):
                last_msg_content = conversation[-1]["content"] or ""
            if isinstance(last_msg_content, list):
                last_msg_content = "\n".join(msg['text']
                                             for msg in last_msg_content)

            for choice in choices:
                full_message = last_msg_content + (choice.message.content
                                                   or "")
                choice.message.content = full_message

        assert final_res.prompt_token_ids is not None
        num_prompt_tokens = len(final_res.prompt_token_ids)
        if final_res.encoder_prompt_token_ids is not None:
            num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        usage = UsageInfo(prompt_tokens=num_prompt_tokens,
                          completion_tokens=num_generated_tokens,
                          total_tokens=num_prompt_tokens +
                          num_generated_tokens)
        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=final_res.num_cached_tokens)

        request_metadata.final_usage_info = usage

        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
            kv_transfer_params=final_res.kv_transfer_params,
        )

        # Log complete response if output logging is enabled
        if self.enable_log_outputs and self.request_logger:
            for choice in choices:
                output_text = ""
                if choice.message.content:
                    output_text = choice.message.content
                elif choice.message.tool_calls:
                    # For tool calls, log the function name and arguments
                    tool_call_descriptions = []
                    for tool_call in choice.message.tool_calls:
                        if hasattr(tool_call.function, "name") and hasattr(
                                tool_call.function, "arguments"):
                            tool_call_descriptions.append(
                                f"{tool_call.function.name}({tool_call.function.arguments})"
                            )
                    tool_calls_str = ", ".join(tool_call_descriptions)
                    output_text = f"[tool_calls: {tool_calls_str}]"

                if output_text:
                    # Get the corresponding output token IDs
                    output_token_ids = None
                    if choice.index < len(final_res.outputs):
                        output_token_ids = final_res.outputs[
                            choice.index].token_ids

                    self.request_logger.log_outputs(
                        request_id=request_id,
                        outputs=output_text,
                        output_token_ids=output_token_ids,
                        finish_reason=choice.finish_reason,
                        is_streaming=False,
                        delta=False,
                    )

        return response
|
||||
# ---- end of serving_chat.py; below: new module glm47 tool parser ----
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
from functools import partial
|
||||
from importlib.resources import contents
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import regex as re
|
||||
from enum import Enum
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ChatCompletionToolsParam,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
class StreamState(str, Enum):
    """State machine states for XML to JSON streaming conversion."""

    # No <arg_key> seen yet for the current tool call.
    INIT = "INIT"
    # A </arg_value> has closed; waiting for the next <arg_key>.
    BETWEEN = "BETWEEN"
    # Inside <arg_key>...</arg_key>, accumulating the key text.
    IN_KEY = "IN_KEY"
    # Key closed; waiting for the opening <arg_value> tag.
    WAITING_VALUE = "WAITING_VALUE"
    # Inside <arg_value>...</arg_value>, streaming the value out as JSON.
    IN_VALUE = "IN_VALUE"
|
||||
|
||||
def random_tool_call_id() -> str:
    """Return a fresh, unique identifier for a tool call."""
    return "chatcmpl-tool-" + random_uuid()
|
||||
|
||||
def get_argument_type(
    func_name: str, arg_key: str, defined_tools: list[ChatCompletionToolsParam]
) -> Optional[str]:
    """Look up the declared type of ``arg_key`` for the tool ``func_name``.

    The tool's JSON Schema ``parameters`` block is inspected; complex schema
    constructs (direct ``type`` fields and type arrays, ``anyOf``/``oneOf``,
    ``enum``, ``allOf``, ``properties``, ``items``) are resolved via
    ``infer_type_from_json_schema``.

    Args:
        func_name: Name of the function/tool
        arg_key: Name of the argument
        defined_tools: List of available tools

    Returns:
        The type string (e.g. 'string', 'number', 'object') or None when the
        tool, its parameters, or the argument spec cannot be found
    """
    by_name = {tool.function.name: tool for tool in defined_tools}

    tool = by_name.get(func_name)
    if tool is None:
        return None

    # Parameters may be absent or not a schema dict; bail out safely.
    parameters = getattr(tool.function, "parameters", None)
    if not isinstance(parameters, dict):
        return None

    props = parameters.get("properties")
    if not isinstance(props, dict):
        return None

    spec = props.get(arg_key)
    if not isinstance(spec, dict):
        return None
    # Delegate complex JSON Schema resolution to the shared helper.
    return infer_type_from_json_schema(spec)
|
||||
|
||||
def _convert_to_number(value: str) -> Any:
|
||||
"""Convert string to appropriate number type (int or float).
|
||||
|
||||
Args:
|
||||
value: String value to convert
|
||||
|
||||
Returns:
|
||||
Converted number or original string if conversion fails
|
||||
"""
|
||||
try:
|
||||
if "." in value or "e" in value.lower():
|
||||
return float(value)
|
||||
else:
|
||||
return int(value)
|
||||
except (ValueError, AttributeError):
|
||||
return value
|
||||
|
||||
|
||||
def parse_arguments(
    json_value: str, arg_type: Optional[str] = None
) -> tuple[Any, bool]:
    """Parse an argument value, trying several decoding strategies in order.

    Args:
        json_value: Raw string value to parse
        arg_type: Expected type hint ('string', 'number', 'object', etc.)

    Returns:
        Tuple of (parsed_value, is_valid_json)
    """

    def _coerce(candidate: Any) -> Any:
        # Honour a declared "number" type when the decoded value is a string.
        if arg_type == "number" and isinstance(candidate, str):
            return _convert_to_number(candidate)
        return candidate

    # 1) The value is already valid JSON.
    try:
        return _coerce(json.loads(json_value)), True
    except (json.JSONDecodeError, ValueError):
        pass

    # 2) The value may be JSON escaped one extra time: unescape via a JSON
    #    string wrapper, then decode the inner payload.
    try:
        unescaped = json.loads('{"tmp": "' + json_value + '"}')["tmp"]
        return _coerce(json.loads(unescaped)), True
    except (json.JSONDecodeError, ValueError, KeyError):
        pass

    # 3) Python literal syntax (single quotes, tuples, ...). No number
    #    coercion on this path, matching the historical behaviour.
    try:
        return ast.literal_eval(json_value), True
    except (ValueError, SyntaxError):
        pass

    # 4) Last resort: treat the raw text as a plain string.
    try:
        return json.loads(json.dumps(str(json_value))), True
    except (json.JSONDecodeError, ValueError):
        return json_value, False
|
||||
|
||||
def infer_type_from_json_schema(schema: dict[str, Any]) -> Optional[str]:
    """
    Infer the primary type of a parameter from a JSON Schema definition.

    Handles, in priority order: a direct ``type`` field (including type
    arrays), ``anyOf``/``oneOf`` alternatives, ``enum`` value inspection,
    ``allOf`` combinations, and the structural hints ``properties`` (object)
    and ``items`` (array).

    Args:
        schema: JSON Schema definition

    Returns:
        Inferred type ('string', 'number', 'object', 'array', etc.) or None
    """
    if not isinstance(schema, dict):
        return None

    # Priority 1: explicit "type" field (string, or array of type names).
    declared = schema.get("type")
    if isinstance(declared, str):
        return declared
    if isinstance(declared, list) and declared:
        for candidate in declared:
            if candidate != "null":
                return candidate
        return "string"  # only "null" entries: fall back to string

    # Priority 2: anyOf / oneOf — infer each alternative recursively.
    if "anyOf" in schema or "oneOf" in schema:
        alternatives = schema.get("anyOf") or schema.get("oneOf")
        collected: list[str] = []
        if isinstance(alternatives, list):
            for sub in alternatives:
                sub_type = infer_type_from_json_schema(sub)
                if sub_type:
                    collected.append(sub_type)
        if collected:
            unique = set(collected)
            if len(unique) == 1:
                return collected[0]
            # Mixed alternatives: prefer string (safest), else the first.
            return "string" if "string" in unique else collected[0]

    # Priority 3: enum — infer the type from the literal values.
    if isinstance(schema.get("enum"), list):
        enum_values = schema["enum"]
        if not enum_values:
            return "string"
        seen: set[str] = set()
        for item in enum_values:
            if item is None:
                seen.add("null")
            elif isinstance(item, bool):
                # bool must be tested before int (bool subclasses int).
                seen.add("boolean")
            elif isinstance(item, int):
                seen.add("integer")
            elif isinstance(item, float):
                seen.add("number")
            elif isinstance(item, str):
                seen.add("string")
            elif isinstance(item, list):
                seen.add("array")
            elif isinstance(item, dict):
                seen.add("object")
        # Uniform enum -> that type; mixed/unknown -> string.
        return seen.pop() if len(seen) == 1 else "string"

    # Priority 4: allOf — the first non-string member type wins.
    if isinstance(schema.get("allOf"), list):
        for member in schema["allOf"]:
            member_type = infer_type_from_json_schema(member)
            if member_type and member_type != "string":
                return member_type
        return "string"

    # Priority 5/6: structural hints.
    if "properties" in schema:
        return "object"
    if "items" in schema:
        return "array"

    return None
|
||||
|
||||
@ToolParserManager.register_module("glm47")
|
||||
class Glm47MoeModelToolParser(ToolParser):
|
||||
|
||||
    def __init__(self, tokenizer: AnyTokenizer):
        """Initialize the GLM-4.x tool parser.

        Args:
            tokenizer: Tokenizer of the served model, used to resolve the
                vocabulary ids of the tool-call marker tokens.

        Raises:
            ValueError: If the base class did not receive a usable tokenizer.
        """
        super().__init__(tokenizer)
        # Per-tool streaming bookkeeping.
        self.current_tool_name_sent = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id = -1
        self.streamed_args_for_tool: list[str] = []
        # Markers the model emits around each tool call.
        self.tool_call_start_token = "<tool_call>"
        self.tool_call_end_token = "</tool_call>"
        # NOTE(review): only touched behind a `hasattr` guard in the
        # streaming path, so that branch is dead — confirm intent.
        self._tool_indices = 0
        # JSON argument text already streamed for the current tool call.
        self._last_arguments: str = ""
        # Number of raw XML characters already fed to the state machine.
        self._streamed_raw_length = 0

        # Alias kept for code that expects the plural name.
        self.tool_calls_start_token = self.tool_call_start_token

        # Full <tool_call>...</tool_call> span.
        self.func_call_regex = re.compile(r"<tool_call>.*?</tool_call>",
                                          re.DOTALL)
        # Group 1: function name (first line), group 2: raw argument XML.
        self.func_detail_regex = re.compile(
            r"<tool_call>([^\n<]*)\n?(.*)</tool_call>", re.DOTALL)
        # One <arg_key>/<arg_value> pair.
        self.func_arg_regex = re.compile(
            r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>",
            re.DOTALL)
        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction.")

        # Marker token ids (None if the tokenizer lacks them).
        self.tool_call_start_token_id = self.vocab.get(
            self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
        # Accumulates delta_text across streaming calls.
        self._buffer = ""
        self._reset_streaming_state()
|
||||
|
||||
    def _reset_streaming_state(self) -> None:
        """Reset the streaming state machine for a new tool call."""
        self._stream_state = StreamState.INIT
        self._current_key = ""  # key text collected inside <arg_key>
        self._current_value = ""  # value text collected inside <arg_value>
        self._xml_tag_buffer = ""  # lookahead buffer for partial XML tags
        self._is_first_param = True  # controls "{" vs ", " separator output
        self._value_started = False  # whether value output has begun
        self._cached_value_type: Optional[str] = (
            None  # Cache the value type for consistency
        )
        self._tool_call_completed = False  # Reset tool call completion status
        self._sent_empty_object = False  # Reset empty object sent status
|
||||
|
||||
    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from a complete (non-streaming) response.

        Parses every ``<tool_call>...</tool_call>`` span in *model_output*,
        turning the ``<arg_key>``/``<arg_value>`` pairs into a JSON arguments
        object. On any parsing error the whole output is returned as plain
        content with no tool calls.
        """

        def _is_string_type(
                tool_name: str, arg_name: str,
                tools: Optional[list[ChatCompletionToolsParam]]) -> bool:
            # True only when the tool schema explicitly declares the argument
            # as type "string"; unknown tools/arguments default to False.
            if tools is None:
                return False
            for tool in tools:
                if tool.function.name == tool_name:
                    if tool.function.parameters is None:
                        return False
                    arg_type = tool.function.parameters.get(
                        "properties", {}).get(arg_name, {}).get("type", None)
                    return arg_type == "string"
            logger.warning("No tool named '%s'.", tool_name)
            return False

        def _deserialize(value: str) -> Any:
            # Best effort: JSON first, then Python literal syntax, then the
            # raw string unchanged.
            try:
                return json.loads(value)
            except Exception:
                pass

            try:
                return ast.literal_eval(value)
            except Exception:
                pass
            return value

        matched_tool_calls = self.func_call_regex.findall(model_output)
        logger.debug("model_output: %s", model_output)
        try:
            tool_calls = []
            for match in matched_tool_calls:
                # Group 1 is the function name, group 2 the raw argument XML.
                tc_detail = self.func_detail_regex.search(match)
                tc_name = tc_detail.group(1)
                tc_args = tc_detail.group(2)
                pairs = self.func_arg_regex.findall(tc_args)
                arg_dct = {}
                for key, value in pairs:
                    arg_key = key.strip()
                    arg_val = value.strip()
                    # String-typed arguments keep their literal text; other
                    # types are deserialized (JSON / Python literal).
                    if not _is_string_type(tc_name, arg_key, request.tools):
                        arg_val = _deserialize(arg_val)
                    logger.debug("arg_key = %s, arg_val = %s", arg_key,
                                 arg_val)
                    arg_dct[arg_key] = arg_val
                tool_calls.append(
                    ToolCall(type="function",
                             function=FunctionCall(
                                 name=tc_name, arguments=json.dumps(arg_dct))))
        except Exception:
            # Any failure degrades gracefully to "no tool calls".
            logger.exception("Failed to extract tool call spec")
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)
        else:
            if len(tool_calls) > 0:
                # Content is everything before the first <tool_call> marker.
                content = model_output[:model_output.
                                       find(self.tool_calls_start_token)]
                return ExtractedToolCallInformation(tools_called=True,
                                                    tool_calls=tool_calls,
                                                    content=content)
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)
|
||||
|
||||
def _extract_match_groups(self, match: re.Match) -> tuple[str, str, str]:
|
||||
"""Extract function name, arguments and end marker from regex match.
|
||||
|
||||
Args:
|
||||
match: Regex match object
|
||||
|
||||
Returns:
|
||||
(func_name, func_args_raw, is_tool_end)
|
||||
"""
|
||||
func_name = match.group(1).strip()
|
||||
func_args_raw = match.group(2).strip() if match.group(2) else ""
|
||||
is_tool_end = match.group(3) or ""
|
||||
return func_name, func_args_raw, is_tool_end
|
||||
|
||||
    def _send_tool_name_if_needed(
        self, func_name: str, has_arg_key: bool, is_tool_end: str
    ) -> Optional[DeltaToolCall]:
        """Send tool name if needed.

        Args:
            func_name: Function name
            has_arg_key: Whether current text contains <arg_key
            is_tool_end: Whether end marker is encountered

        Returns:
            Tool call item or None
        """
        if self.current_tool_name_sent:
            return None

        # Function name completeness check: the name is only known complete
        # once arguments start or the tool call is closed.
        is_func_name_complete = has_arg_key or is_tool_end == self.tool_call_end_token

        if not is_func_name_complete:
            return None

        if not func_name:
            logger.warning("Empty function name detected, skipping tool call")
            return None

        # Send tool name
        self.current_tool_name_sent = True
        self._streamed_raw_length = 0
        self._reset_streaming_state()

        # Record tool info
        # NOTE(review): this index-assigns into prev_tool_call_arr and so
        # assumes the slot for current_tool_id already exists (or that a
        # negative index is valid) — confirm against the streaming caller
        # that advances current_tool_id / grows the list.
        self.prev_tool_call_arr[self.current_tool_id] = {
            "name": func_name,
            "arguments": {},
        }

        return DeltaToolCall(
            id=random_tool_call_id(),
            index=self.current_tool_id,
            type="function",
            function=DeltaFunctionCall(name=func_name, arguments=""),
        )
|
||||
|
||||
def _parse_argument_pairs(
|
||||
self, pairs: list[tuple[str, str]], func_name: str, tools: list[ChatCompletionToolsParam]
|
||||
) -> dict[str, Any]:
|
||||
"""Parse argument key-value pairs with type coercion.
|
||||
|
||||
Args:
|
||||
pairs: List of (key, value) tuples from regex matching
|
||||
func_name: Name of the function
|
||||
tools: List of available tools
|
||||
|
||||
Returns:
|
||||
Dictionary of parsed arguments
|
||||
"""
|
||||
arguments = {}
|
||||
for arg_key, arg_value in pairs:
|
||||
arg_key = arg_key.strip()
|
||||
arg_value = arg_value.strip()
|
||||
arg_type = get_argument_type(func_name, arg_key, tools)
|
||||
parsed_value, is_good_json = parse_arguments(arg_value, arg_type)
|
||||
|
||||
if arg_type == "string":
|
||||
# Only convert to string if explicitly defined as string type
|
||||
if isinstance(parsed_value, str):
|
||||
arguments[arg_key] = parsed_value
|
||||
elif isinstance(parsed_value, (dict, list)):
|
||||
# If parsed as dict/list but schema says string, convert to JSON string
|
||||
arguments[arg_key] = json.dumps(parsed_value, ensure_ascii=False)
|
||||
else:
|
||||
arguments[arg_key] = str(parsed_value)
|
||||
elif arg_type is None:
|
||||
# If type is not defined, keep the parsed value as-is
|
||||
arguments[arg_key] = parsed_value if is_good_json else arg_value
|
||||
else:
|
||||
# For other types (number, object, array, etc.), use parsed value
|
||||
arguments[arg_key] = parsed_value if is_good_json else arg_value
|
||||
|
||||
return arguments
|
||||
|
||||
    def _finalize_tool_call(
        self,
        func_name: str,
        func_args_raw: str,
        tools: list[ChatCompletionToolsParam],
        match_end_pos: int,
        current_text: str,
    ) -> list[DeltaToolCall]:
        """Complete tool call processing.

        Emits whatever JSON is still needed to close the streamed arguments
        object, re-parses the full raw argument XML for the final record,
        trims the processed span from the buffer, and resets per-tool
        streaming state.

        Args:
            func_name: Function name
            func_args_raw: Raw argument string
            tools: List of available tools
            match_end_pos: Match end position
            current_text: Current text

        Returns:
            List of tool call items to add
        """
        calls = []

        # Handle no-arg function or need to close braces
        if self._is_first_param and not self._sent_empty_object:
            # No-arg function: nothing was streamed, emit an empty object.
            calls.append(
                DeltaToolCall(
                    index=self.current_tool_id,
                    function=DeltaFunctionCall(name=None, arguments="{}"),
                )
            )
            self._last_arguments += "{}"
            self.streamed_args_for_tool[self.current_tool_id] += "{}"
            self._sent_empty_object = True
        elif not self._last_arguments.endswith("}") and not self._sent_empty_object:
            # Arguments were streamed but the object is still open.
            calls.append(
                DeltaToolCall(
                    index=self.current_tool_id,
                    function=DeltaFunctionCall(name=None, arguments="}"),
                )
            )
            self._last_arguments += "}"
            self.streamed_args_for_tool[self.current_tool_id] += "}"
            self._sent_empty_object = True

        # Parse final arguments from the complete XML so the recorded tool
        # call carries fully parsed values, not just streamed text.
        if func_args_raw:
            try:
                pairs = self.func_arg_regex.findall(func_args_raw)
                if pairs:
                    arguments = self._parse_argument_pairs(pairs, func_name, tools)
                    self.prev_tool_call_arr[self.current_tool_id][
                        "arguments"
                    ] = arguments
            except Exception as e:
                logger.debug(f"Failed to parse arguments: {e}", exc_info=True)

        # Clean buffer: keep only the text after the closed tool call.
        self._buffer = current_text[match_end_pos:]

        # Reset state for next tool call
        self._tool_call_completed = True
        self.current_tool_id += 1
        self._last_arguments = ""
        self.current_tool_name_sent = False
        self._streamed_raw_length = 0
        self._reset_streaming_state()

        return calls
|
||||
|
||||
def _format_value_complete(self, value: str, value_type: str) -> str:
|
||||
"""Format complete value based on type.
|
||||
|
||||
Args:
|
||||
value: Raw value string
|
||||
value_type: Expected type ('string', 'number', 'object')
|
||||
|
||||
Returns:
|
||||
Properly formatted JSON value string
|
||||
"""
|
||||
if value_type == "string":
|
||||
# Ensure proper JSON string formatting with quotes
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
elif value_type == "number":
|
||||
try:
|
||||
num = _convert_to_number(value.strip() if value else "")
|
||||
return str(num)
|
||||
except (ValueError, AttributeError):
|
||||
# Fallback to string if not a valid number
|
||||
logger.warning(
|
||||
f"Failed to parse '{value}' as number, treating as string"
|
||||
)
|
||||
return json.dumps(str(value) if value else "", ensure_ascii=False)
|
||||
else:
|
||||
# For object/array types, return as-is (should already be valid JSON)
|
||||
return value
|
||||
|
||||
|
||||
    def _process_xml_to_json_streaming(
        self, raw_increment: str, func_name: str, tools: list[ChatCompletionToolsParam]
    ) -> str:
        """Convert XML increment to JSON streaming output using state machine.

        This method processes XML fragments character by character and converts them
        to JSON format incrementally. It maintains state across calls to handle
        partial XML tags and values.

        Args:
            raw_increment: New XML content to process
            func_name: Name of the function being called
            tools: List of available tools for type inference

        Returns:
            JSON string increment to append to the output
        """
        json_output = ""

        for char in raw_increment:
            # Every character first lands in the tag buffer; it is flushed
            # either as tag syntax or as value content.
            self._xml_tag_buffer += char

            if self._stream_state in [StreamState.INIT, StreamState.BETWEEN]:
                if self._xml_tag_buffer.endswith("<arg_key>"):
                    self._stream_state = StreamState.IN_KEY
                    self._current_key = ""
                    self._xml_tag_buffer = ""
                    # Open the object on the first key, separate later ones.
                    json_output += "{" if self._is_first_param else ", "
                    self._is_first_param = False

            elif self._stream_state == StreamState.IN_KEY:
                if self._xml_tag_buffer.endswith("</arg_key>"):
                    # Strip the 10-character "</arg_key>" suffix.
                    self._current_key = self._xml_tag_buffer[:-10].strip()
                    self._xml_tag_buffer = ""
                    self._stream_state = StreamState.WAITING_VALUE
                    json_output += (
                        json.dumps(self._current_key, ensure_ascii=False) + ": "
                    )

            elif self._stream_state == StreamState.WAITING_VALUE:
                if self._xml_tag_buffer.endswith("<arg_value>"):
                    self._stream_state = StreamState.IN_VALUE
                    self._current_value = ""
                    self._xml_tag_buffer = ""
                    self._value_started = False
                    # Determine and cache the value type at the start
                    self._cached_value_type = self._get_value_type(
                        func_name, self._current_key, tools
                    )

            elif self._stream_state == StreamState.IN_VALUE:
                if self._xml_tag_buffer.endswith("</arg_value>"):
                    # Strip the 12-character "</arg_value>" suffix.
                    final_value = self._xml_tag_buffer[:-12]
                    self._current_value += final_value

                    # Use cached value type for consistency
                    value_type = self._cached_value_type or "string"

                    if self._value_started:
                        # Output any remaining content
                        if final_value:
                            if value_type == "string":
                                json_output += json.dumps(
                                    final_value, ensure_ascii=False
                                )[1:-1]
                            else:
                                json_output += final_value
                        # Always output closing quote for string type when value was started
                        if value_type == "string":
                            json_output += '"'
                    else:
                        # Value was never started (empty or complete in one chunk)
                        json_output += self._format_value_complete(
                            self._current_value, value_type
                        )

                    self._xml_tag_buffer = ""
                    self._stream_state = StreamState.BETWEEN
                    self._current_value = ""
                    self._value_started = False
                    self._cached_value_type = None  # Reset cached type
                else:
                    # Hold back output while the buffer could still be a
                    # prefix of the closing tag.
                    closing_tag = "</arg_value>"
                    is_potential_closing = len(self._xml_tag_buffer) <= len(
                        closing_tag
                    ) and closing_tag.startswith(self._xml_tag_buffer)

                    if not is_potential_closing:
                        content = self._xml_tag_buffer
                        # Use cached value type for consistency
                        value_type = self._cached_value_type or "string"

                        if value_type == "string":
                            if not self._value_started:
                                # Emit the opening quote lazily.
                                json_output += '"'
                                self._value_started = True
                            if content:
                                # Escape the chunk as JSON string content
                                # (dumps then strip the surrounding quotes).
                                json_output += json.dumps(content, ensure_ascii=False)[
                                    1:-1
                                ]
                                self._current_value += content
                            self._xml_tag_buffer = ""
                        elif value_type == "number":
                            if content:
                                if not self._value_started:
                                    self._value_started = True
                                json_output += content
                                self._current_value += content
                            self._xml_tag_buffer = ""
                        else:
                            # For object/array types, output as-is
                            if content:
                                if not self._value_started:
                                    self._value_started = True
                                json_output += content
                                self._current_value += content
                            self._xml_tag_buffer = ""

        return json_output
|
||||
|
||||
def _get_value_type(self, func_name: str, key: str, tools: list[ChatCompletionToolsParam]) -> str:
|
||||
"""Get parameter type from tool definition, with fallback to auto-detection.
|
||||
|
||||
Args:
|
||||
func_name: Name of the function
|
||||
key: Parameter name
|
||||
tools: List of available tools
|
||||
|
||||
Returns:
|
||||
Type string: 'string', 'number', 'object', 'array', or 'boolean'
|
||||
"""
|
||||
arg_type = get_argument_type(func_name, key, tools)
|
||||
if arg_type:
|
||||
return arg_type
|
||||
|
||||
# Improved auto-detection type from value (best effort)
|
||||
value_content = self._current_value.strip() if self._current_value else ""
|
||||
|
||||
if not value_content:
|
||||
return "string"
|
||||
|
||||
# Try to parse as valid JSON first
|
||||
try:
|
||||
parsed = json.loads(value_content)
|
||||
if isinstance(parsed, dict):
|
||||
return "object"
|
||||
elif isinstance(parsed, list):
|
||||
return "array"
|
||||
elif isinstance(parsed, bool):
|
||||
return "boolean"
|
||||
elif isinstance(parsed, (int, float)):
|
||||
return "number"
|
||||
# For string values, check if they look like numbers
|
||||
elif isinstance(parsed, str):
|
||||
if parsed.isdigit() or (
|
||||
parsed.startswith("-") and parsed[1:].isdigit()
|
||||
):
|
||||
return "number"
|
||||
return "string"
|
||||
except json.JSONDecodeError:
|
||||
# Not valid JSON, try heuristic detection
|
||||
first_char = value_content[0] if value_content else ""
|
||||
|
||||
if first_char.isdigit() or first_char in ["-", "."]:
|
||||
return "number"
|
||||
elif first_char in ["{", "["]:
|
||||
return "object"
|
||||
elif first_char in ['"', "'"]:
|
||||
return "string"
|
||||
|
||||
# Default to string (safest fallback)
|
||||
return "string"
|
||||
|
||||
|
||||
def _process_arguments_streaming(
|
||||
self, func_name: str, func_args_raw: str, tools: list[ChatCompletionToolsParam]
|
||||
) -> Optional[DeltaToolCall]:
|
||||
"""Process streaming arguments.
|
||||
|
||||
Args:
|
||||
func_name: Function name
|
||||
func_args_raw: Raw argument string
|
||||
tools: List of available tools
|
||||
|
||||
Returns:
|
||||
Tool call item with parameter updates or None
|
||||
"""
|
||||
current_raw_length = len(func_args_raw)
|
||||
|
||||
if current_raw_length <= self._streamed_raw_length:
|
||||
return None
|
||||
|
||||
# Get new raw XML content
|
||||
raw_increment = func_args_raw[self._streamed_raw_length :]
|
||||
|
||||
# Convert XML to JSON using state machine
|
||||
json_increment = self._process_xml_to_json_streaming(
|
||||
raw_increment, func_name, tools
|
||||
)
|
||||
|
||||
# CRITICAL: Update streamed length BEFORE early return
|
||||
# Even if json_increment is empty, the input has been consumed by the state machine
|
||||
self._streamed_raw_length = current_raw_length
|
||||
|
||||
if not json_increment:
|
||||
return None
|
||||
|
||||
# Update state
|
||||
self._last_arguments += json_increment
|
||||
self.streamed_args_for_tool[self.current_tool_id] += json_increment
|
||||
|
||||
return DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(name=None, arguments=json_increment),
|
||||
)
|
||||
|
||||
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
    """Incrementally parse tool calls out of a streamed model response.

    Buffers ``delta_text`` chunks until either the buffer is proven not to
    be a tool call (then flushed as plain content) or a tool call can be
    progressively parsed (name first, then argument deltas, then a final
    complete call when the end token arrives).

    Returns:
        A DeltaMessage with content and/or tool-call deltas, or None when
        nothing can be emitted yet.
    """
    self._buffer += delta_text
    current_text = self._buffer

    # Check if we have a tool call
    has_tool_call = self.tool_call_start_token in current_text

    if not has_tool_call:
        # Keep buffering if the tail of the buffer could still be a
        # partial match of the tool-call start token.
        is_potential_start = any(
            self.tool_call_start_token.startswith(current_text[-i:])
            for i in range(1, min(len(current_text), len(self.tool_call_start_token)) + 1)
        )

        if not is_potential_start:
            # Not a tool call: flush the entire buffer (current_text), not
            # just delta_text, because the buffer may hold previously
            # withheld characters such as '<' that turned out not to be
            # part of a tool call.
            output_text = current_text
            self._buffer = ""
            if self.tool_call_end_token in output_text:
                output_text = output_text.replace(self.tool_call_end_token, "")
            return DeltaMessage(content=output_text)
        else:
            # Could be the start of a tool call; keep buffering.
            return None

    # Emit any text that precedes the first start token as normal content.
    output_text = ""
    first_bot_token_idx = current_text.find(self.tool_call_start_token)
    if first_bot_token_idx > 0:
        output_text = current_text[:first_bot_token_idx]
        current_text = current_text[first_bot_token_idx:]
        # Update buffer to only include from the bot token onwards
        self._buffer = current_text

    # BUGFIX: the original `self._tool_indices += 1` inside this guard
    # always raised AttributeError, since the attribute does not exist
    # when the branch is taken. Initialize it instead.
    # NOTE(review): the attribute is unused in this method; confirm the
    # intended initial value/type against the rest of the class.
    if not hasattr(self, "_tool_indices"):
        self._tool_indices = 0

    calls: list[DeltaToolCall] = []
    try:
        # Match a partial or complete tool call with a single flexible
        # pattern: function name, optional raw <arg_key...> arguments,
        # and an optional end token (or end of buffer).
        partial_match = re.search(
            r"<tool_call>(.*?)(?:(<arg_key.*?))?(?:(</tool_call>)|$)",
            current_text,
            re.DOTALL,
        )
        if not partial_match:
            return None

        # Extract match groups using helper method
        func_name, func_args_raw, is_tool_end = self._extract_match_groups(
            match=partial_match
        )

        # First tool call in this response: initialize tracking state.
        if self.current_tool_id == -1:
            self.current_tool_id = 0
            self.prev_tool_call_arr = []
            self.streamed_args_for_tool = [""]
            self._streamed_raw_length = 0
            self.current_tool_name_sent = False  # Reset for new tool call
            self._reset_streaming_state()
        elif not self.current_tool_name_sent:
            # Continuation of the current tool call: do NOT bump
            # current_tool_id here. A new id is only warranted once a
            # previous call has COMPLETED and another start token appears
            # in new text; auto-incrementing here was a known bug.
            pass

        # Grow tracking arrays so current_tool_id is always a valid index.
        while len(self.prev_tool_call_arr) <= self.current_tool_id:
            self.prev_tool_call_arr.append({})
        while len(self.streamed_args_for_tool) <= self.current_tool_id:
            self.streamed_args_for_tool.append("")

        # The function name is considered complete once an <arg_key shows
        # up anywhere in the buffer (arguments may arrive in later chunks).
        has_arg_key = "<arg_key" in current_text
        tool_name_item = self._send_tool_name_if_needed(
            func_name, has_arg_key, is_tool_end
        )
        if tool_name_item:
            calls.append(tool_name_item)

        # Stream argument deltas once the name has been emitted.
        if self.current_tool_name_sent and request.tools:
            arg_item = self._process_arguments_streaming(
                func_name, func_args_raw, request.tools
            )
            if arg_item:
                calls.append(arg_item)

        # Close out the tool call when the end token has arrived.
        if is_tool_end == self.tool_call_end_token and not self._tool_call_completed:
            finalize_calls = self._finalize_tool_call(
                func_name,
                func_args_raw,
                request.tools,
                partial_match.end(),
                current_text,
            )
            calls.extend(finalize_calls)
            return DeltaMessage(content=output_text, tool_calls=calls)

    except Exception as e:
        # Lazy %-formatting: the message is only built when the log is emitted.
        logger.error("Error in parse_streaming_increment: %s", e, exc_info=True)
        return DeltaMessage(content=output_text)

    # Only return if we have meaningful content or tool calls to avoid
    # emitting empty chunks.
    if output_text.strip() or calls:
        return DeltaMessage(content=output_text, tool_calls=calls)
    else:
        return None
||||
Reference in New Issue
Block a user