From 726cefb7a39448eb3d6bbff80cb21586f01da159 Mon Sep 17 00:00:00 2001 From: astrophel0 <128400279+astrophel0@users.noreply.github.com> Date: Fri, 30 Jan 2026 15:24:14 +0800 Subject: [PATCH] [dev]add glm4.7 tool-parser (#151) Signed-off-by: zhangzhenyi Co-authored-by: Li Wei --- .../entrypoints/openai/serving_chat.py | 948 ++++++++++++++++++ .../tool_parsers/glm47_moe_tool_parser.py | 912 +++++++++++++++++ 2 files changed, 1860 insertions(+) create mode 100644 vllm_kunlun/entrypoints/openai/serving_chat.py create mode 100644 vllm_kunlun/entrypoints/openai/tool_parsers/glm47_moe_tool_parser.py diff --git a/vllm_kunlun/entrypoints/openai/serving_chat.py b/vllm_kunlun/entrypoints/openai/serving_chat.py new file mode 100644 index 0000000..af03fa2 --- /dev/null +++ b/vllm_kunlun/entrypoints/openai/serving_chat.py @@ -0,0 +1,948 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import json +import time +from collections.abc import AsyncGenerator, AsyncIterator +from collections.abc import Sequence as GenericSequence +from typing import Callable, Final, Optional, Union + +import jinja2 +import partial_json_parser +import regex as re +from fastapi import Request +from openai_harmony import Message as OpenAIMessage +from pydantic import TypeAdapter + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, + ConversationMessage, + random_tool_call_id) +from vllm.entrypoints.harmony_utils import ( + get_developer_message, get_stop_tokens_for_assistant_actions, + get_streamable_parser_for_assistant, get_system_message, parse_chat_input, + parse_chat_output, render_for_completion) +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import ( + ChatCompletionLogProb, ChatCompletionLogProbs, + ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam, + 
ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, + DeltaToolCall, ErrorResponse, FunctionCall, FunctionDefinition, + PromptTokenUsageInfo, RequestResponseMetadata, ToolCall, UsageInfo) +from vllm.entrypoints.openai.serving_engine import (OpenAIServing, + clamp_prompt_logprobs) +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager +from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( + MistralToolCall) +from vllm.entrypoints.utils import get_max_tokens +from vllm.inputs.data import TokensPrompt as EngineTokensPrompt +from vllm.logger import init_logger +from vllm.outputs import CompletionOutput, RequestOutput +from vllm.reasoning import ReasoningParser, ReasoningParserManager +from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.sequence import Logprob +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls, + truncate_tool_call_ids, + validate_request_params) +from vllm.utils import as_list + +logger = init_logger(__name__) + + +class OpenAIServingChat(OpenAIServing): + + async def chat_completion_stream_generator( + self, + request: ChatCompletionRequest, + result_generator: AsyncIterator[RequestOutput], + request_id: str, + model_name: str, + conversation: list[ConversationMessage], + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + enable_force_include_usage: bool, + ) -> AsyncGenerator[str, None]: + created_time = int(time.time()) + chunk_object_type: Final = "chat.completion.chunk" + first_iteration = True + + # Send response for each token for each request.n (index) + num_choices = 1 if request.n is None else request.n + previous_num_tokens = [0] * num_choices + 
finish_reason_sent = [False] * num_choices + num_prompt_tokens = 0 + num_cached_tokens = None + if self.use_harmony: + harmony_parsers = [ + get_streamable_parser_for_assistant() + for _ in range(num_choices) + ] + + if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): + tool_choice_function_name = request.tool_choice.function.name + else: + tool_choice_function_name = None + + # Determine whether tools are in use with "auto" tool choice + tool_choice_auto = ( + not tool_choice_function_name + and self._should_stream_with_auto_tool_parsing(request)) + + all_previous_token_ids: Optional[list[list[int]]] + function_name_returned = [False] * num_choices + + # Always track previous_texts for comprehensive output logging + previous_texts = [""] * num_choices + + # Only one of these will be used, thus previous_texts and + # all_previous_token_ids will not be used twice in the same iteration. + if tool_choice_auto or self.reasoning_parser: + # These are only required in "auto" tool choice case + all_previous_token_ids = [[] for _ in range(num_choices)] + # For reasoning parser and tool call all enabled + added_content_delta_arr = [False] * num_choices + reasoning_end_arr = [False] * num_choices + elif request.tool_choice == "required": + all_previous_token_ids = None + else: + all_previous_token_ids = None + + enable_thinking: bool = request.chat_template_kwargs.get("enable_thinking", True) if request.chat_template_kwargs else True + + try: + if self.reasoning_parser: + reasoning_parser = self.reasoning_parser(tokenizer) + except RuntimeError as e: + logger.exception("Error in reasoning parser creation.") + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + yield "data: [DONE]\n\n" + return + # Prepare the tool parser if it's needed + try: + if tool_choice_auto and self.tool_parser: + tool_parsers: list[Optional[ToolParser]] = [ + self.tool_parser(tokenizer) + for _ in range(num_choices)] + else: + tool_parsers = [None] * num_choices + except
Exception as e: + logger.exception("Error in tool parser creation.") + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + yield "data: [DONE]\n\n" + return + + stream_options = request.stream_options + if stream_options: + include_usage = stream_options.include_usage \ + or enable_force_include_usage + include_continuous_usage = include_usage and \ + stream_options.continuous_usage_stats + else: + include_usage, include_continuous_usage = False, False + + try: + async for res in result_generator: + if res.prompt_token_ids is not None: + num_prompt_tokens = len(res.prompt_token_ids) + if res.encoder_prompt_token_ids is not None: + num_prompt_tokens += len(res.encoder_prompt_token_ids) + + # We need to do it here, because if there are exceptions in + # the result_generator, it needs to be sent as the FIRST + # response (by the try...catch). + if first_iteration: + num_cached_tokens = res.num_cached_tokens + # Send first response for each request.n (index) with + # the role + role = self.get_chat_request_role(request) + + # NOTE num_choices defaults to 1 so this usually executes + # once per request + for i in range(num_choices): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage( + role=role, + content="", + ), + logprobs=None, + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + + # if continuous usage stats are requested, add it + if include_continuous_usage: + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens) + + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # Send response to echo the input portion of the + # last message + if request.echo: + last_msg_content: Union[str, list[dict[str, str]]] = "" + if conversation and "content" in conversation[ + -1] and 
conversation[-1].get("role") == role: + last_msg_content = conversation[-1]["content"] or "" + + if last_msg_content: + for i in range(num_choices): + choice_data = ( + ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage( + content=last_msg_content), + logprobs=None, + finish_reason=None)) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + if include_continuous_usage: + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens) + + data = chunk.model_dump_json( + exclude_unset=True) + yield f"data: {data}\n\n" + first_iteration = False + + for output in res.outputs: + i = output.index + tool_parser = tool_parsers[i] + + if finish_reason_sent[i]: + continue + + if request.logprobs and request.top_logprobs is not None: + assert output.logprobs is not None, ( + "Did not output logprobs") + logprobs = self._create_chat_logprobs( + token_ids=output.token_ids, + top_logprobs=output.logprobs, + tokenizer=tokenizer, + num_output_top_logprobs=request.top_logprobs, + return_as_token_id=request. + return_tokens_as_token_ids, + ) + else: + logprobs = None + + if self.use_harmony: + harmony_parser = harmony_parsers[i] + for token_id in output.token_ids: + harmony_parser.process(token_id) + # FIXME(woosuk): Support function calling + is_final = harmony_parser.current_channel == "final" + if not (request.include_reasoning or is_final): + # Skip the reasoning content. 
+ continue + delta_text = harmony_parser.last_content_delta or "" + else: + delta_text = output.text + + if not delta_text and not output.token_ids and \ + not previous_num_tokens[i]: + # Chunked prefill case, don't return empty chunks + continue + + delta_message: Optional[DeltaMessage] + + # just update previous_texts and previous_token_ids + if ((tool_choice_auto or self.reasoning_parser) + and not self.use_harmony): + assert previous_texts is not None + assert all_previous_token_ids is not None + previous_text = previous_texts[i] + previous_token_ids = all_previous_token_ids[i] + current_text = previous_text + delta_text + + # avoid the None + list error. + if previous_token_ids: + current_token_ids = previous_token_ids + as_list( + output.token_ids) + else: + current_token_ids = as_list(output.token_ids) + + if self.use_harmony: + if is_final: + delta_message = DeltaMessage(content=delta_text) + else: + delta_message = DeltaMessage( + reasoning_content=delta_text) + # handle streaming deltas for tools with named tool_choice + elif tool_choice_function_name: + if (self.reasoning_parser and not reasoning_end_arr[i] + and not reasoning_parser.is_reasoning_end( + previous_token_ids)): + assert reasoning_parser is not None + delta_message = ( + reasoning_parser. + extract_reasoning_content_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output.token_ids, + )) + # When encountering think end id in delta_token_ids + # or think end id in prompt_token_ids + # i.e {"enable_thinking": False}, + # set reasoning status to end. + # Only keep 'content', remove 'reasoning_content'. 
+ if reasoning_parser.is_reasoning_end( + as_list(output.token_ids)) or ( + res.prompt_token_ids + and reasoning_parser.is_reasoning_end( + res.prompt_token_ids)): + reasoning_end_arr[i] = True + if delta_message and delta_message.content: + # This need to be added to next `delta_text` + current_text = delta_message.content + delta_message.content = None + else: + current_text = "" + else: + # Just to add remaining `content` + if self.reasoning_parser: + delta_text = previous_text + delta_text + current_text = "" + + if function_name_returned[i]: + delta_tool_call = DeltaToolCall( + function=DeltaFunctionCall( + arguments=delta_text), + index=i) + else: + delta_tool_call = DeltaToolCall( + id=random_tool_call_id(), + type="function", + function=DeltaFunctionCall( + name=tool_choice_function_name, + arguments=delta_text), + index=i) + function_name_returned[i] = True + + delta_message = DeltaMessage(tool_calls=[ + delta_tool_call, + ]) + + elif request.tool_choice == "required": + assert previous_texts is not None + previous_text = previous_texts[i] + current_text = previous_text + delta_text + fn_name_returned = function_name_returned[i] + + if self.reasoning_parser: + _, content = \ + reasoning_parser.extract_reasoning_content( + current_text, + request + ) + else: + content = current_text + delta_message, function_name_returned[i] = ( + self.extract_tool_call_required_streaming( + previous_text=previous_text, + current_text=content, + delta_text=delta_text, + function_name_returned=fn_name_returned)) + + # update the previous values for the next iteration + previous_texts[i] = current_text + + # handle streaming deltas for tools with "auto" tool choice + # and reasoning parser + elif tool_choice_auto and self.reasoning_parser: + assert tool_parser is not None + assert reasoning_parser is not None + assert added_content_delta_arr is not None + assert reasoning_end_arr is not None + output_token_ids = as_list(output.token_ids) + if not reasoning_end_arr[i]: + 
delta_message = ( + reasoning_parser. + extract_reasoning_content_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output_token_ids, + )) + # When encountering think end id in prompt_token_ids + # i.e {"enable_thinking": False}, + # set reasoning status to end. + # Remove the text and token ids related + # to 'reasoning_content'. + if not enable_thinking: + reasoning_end_arr[i] = True + current_token_ids = output_token_ids + if delta_message and delta_message.reasoning_content: + current_text = delta_message.reasoning_content + delta_message.content = None + delta_message.reasoning_content = None + else: + current_text = delta_message.content + # When encountering think end id in delta_token_ids, + # set reasoning status to end. + # Remove the text and token ids related + # to 'reasoning_content'. + if reasoning_parser.is_reasoning_end( + output_token_ids): + reasoning_end_arr[i] = True + current_token_ids = \ + reasoning_parser.extract_content_ids( + output_token_ids) + if delta_message and delta_message.content: + current_text = delta_message.content + delta_message.content = None + else: + current_text = "" + + # handle tool calls only after reasoning is done, + else: + delta_token_ids = output_token_ids + # First time to tool call, + # add the remaining text and token ids + # to delta from previous + if not added_content_delta_arr[i]: + added_content_delta_arr[i] = True + previous_text = "" + previous_token_ids = [] + delta_text = current_text + delta_token_ids = current_token_ids + + delta_message = ( + tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + request=request)) + # when only tool calls + elif tool_choice_auto: + assert tool_parser is not None + delta_message = ( + tool_parser.extract_tool_calls_streaming( + 
previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=output.token_ids, + request=request)) + + # when only reasoning + elif self.reasoning_parser and enable_thinking: + delta_message = (reasoning_parser. + extract_reasoning_content_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output.token_ids, + )) + # handle streaming just a content delta + else: + delta_message = DeltaMessage(content=delta_text) + + # update the previous values for the next iteration + if tool_choice_auto or self.reasoning_parser: + assert previous_texts is not None + assert all_previous_token_ids is not None + previous_texts[i] = current_text + all_previous_token_ids[i] = current_token_ids + else: + # Update for comprehensive logging even in simple case + assert previous_texts is not None + previous_texts[i] += delta_text + + # set the previous values for the next iteration + previous_num_tokens[i] += len(output.token_ids) + + # if the message delta is None (e.g. 
because it was a + # "control token" for tool calls or the parser otherwise + # wasn't ready to send a token, then + # get the next token without streaming a chunk + if delta_message is None: + continue + + # Log streaming delta if output logging is enabled + if self.enable_log_outputs and self.request_logger: + delta_content = "" + if delta_message.content: + delta_content = delta_message.content + elif delta_message.tool_calls: + delta_content = "".join( + tc.function.arguments + for tc in delta_message.tool_calls + if tc.function and tc.function.arguments) + + if delta_content: + self.request_logger.log_outputs( + request_id=request_id, + outputs=delta_content, + output_token_ids=as_list(output.token_ids), + finish_reason=output.finish_reason, + is_streaming=True, + delta=True, + ) + + if output.finish_reason is None: + # Send token-by-token response for each request.n + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=delta_message, + logprobs=logprobs, + finish_reason=None) + + # if the model is finished generating + else: + # check to make sure we haven't "forgotten" to stream + # any tokens that were generated but previously + # matched by partial json parsing + # only happens if we are NOT using guided decoding + auto_tools_called = False + if tool_parser: + auto_tools_called = len( + tool_parser.prev_tool_call_arr) > 0 + index = len(tool_parser.prev_tool_call_arr + ) - 1 if auto_tools_called else 0 + else: + index = 0 + + if self._should_check_for_unstreamed_tool_arg_tokens( + delta_message, output) and tool_parser: + latest_delta_len = 0 + if ((isinstance( + delta_message.tool_calls[0].function, + DeltaFunctionCall)) and isinstance( + delta_message.tool_calls[0].function. + arguments, str)): + latest_delta_len = len( + delta_message.tool_calls[0].function. 
+ arguments) + + # get the expected call based on partial JSON + # parsing which "autocompletes" the JSON + expected_call = json.dumps( + tool_parser.prev_tool_call_arr[index].get( + "arguments", {}), + ensure_ascii=False) + + # get what we've streamed so far for arguments + # for the current tool + actual_call = tool_parser.streamed_args_for_tool[ + index] + if (latest_delta_len > 0): + actual_call = actual_call[:-latest_delta_len] + + # check to see if there's anything left to stream + remaining_call = expected_call.replace( + actual_call, "", 1) + # set that as a delta message + delta_message = DeltaMessage(tool_calls=[ + DeltaToolCall(index=index, + function=DeltaFunctionCall( + arguments=remaining_call). + model_dump(exclude_none=True)) + ]) + + # Send the finish response for each request.n only once + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=delta_message, + logprobs=logprobs, + finish_reason=output.finish_reason + if not auto_tools_called else "tool_calls", + stop_reason=output.stop_reason) + + finish_reason_sent[i] = True + + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + + # handle usage stats if requested & if continuous + if include_continuous_usage: + completion_tokens = previous_num_tokens[i] + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, + ) + + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # once the final token is handled, if stream_options.include_usage + # is sent, send the usage + if include_usage: + completion_tokens = sum(previous_num_tokens) + final_usage = UsageInfo(prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + + completion_tokens) + if self.enable_prompt_tokens_details and num_cached_tokens: + 
final_usage.prompt_tokens_details = PromptTokenUsageInfo( + cached_tokens=num_cached_tokens) + + final_usage_chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[], + model=model_name, + usage=final_usage) + final_usage_data = (final_usage_chunk.model_dump_json( + exclude_unset=True, exclude_none=True)) + yield f"data: {final_usage_data}\n\n" + + # report to FastAPI middleware aggregate usage across all choices + num_completion_tokens = sum(previous_num_tokens) + request_metadata.final_usage_info = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_completion_tokens, + total_tokens=num_prompt_tokens + num_completion_tokens, + ) + + # Log complete streaming response if output logging is enabled + if self.enable_log_outputs and self.request_logger: + # Log the complete response for each choice + for i in range(num_choices): + full_text = ( + previous_texts[i] + if previous_texts and i < len(previous_texts) else + f"" + ) + self.request_logger.log_outputs( + request_id=request_id, + outputs=full_text, + output_token_ids= + None, # Consider also logging all token IDs + finish_reason="streaming_complete", + is_streaming=True, + delta=False, + ) + + except Exception as e: + # TODO: Use a vllm-specific Validation Error + logger.exception("Error in chat completion stream generator.") + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + # Send the final done message after all response.n are finished + yield "data: [DONE]\n\n" + + async def chat_completion_full_generator( + self, + request: ChatCompletionRequest, + result_generator: AsyncIterator[RequestOutput], + request_id: str, + model_name: str, + conversation: list[ConversationMessage], + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + ) -> Union[ErrorResponse, ChatCompletionResponse]: + + created_time = int(time.time()) + final_res: Optional[RequestOutput] = None + + try: + async 
for res in result_generator: + final_res = res + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + assert final_res is not None + + choices: list[ChatCompletionResponseChoice] = [] + + role = self.get_chat_request_role(request) + for output in final_res.outputs: + token_ids = output.token_ids + out_logprobs = output.logprobs + + if request.logprobs and request.top_logprobs is not None: + assert out_logprobs is not None, "Did not output logprobs" + logprobs = self._create_chat_logprobs( + token_ids=token_ids, + top_logprobs=out_logprobs, + num_output_top_logprobs=request.top_logprobs, + tokenizer=tokenizer, + return_as_token_id=request.return_tokens_as_token_ids, + ) + else: + logprobs = None + + if self.use_harmony: + reasoning_content, final_content, is_tool_call = ( + parse_chat_output(token_ids)) + if not request.include_reasoning: + reasoning_content = None + + if is_tool_call: + # TODO(woosuk): Implement tool call for gpt-oss. + # For now, only Responses API supports tool call for + # gpt-oss. + raise NotImplementedError( + "Tool call in Chat Completion API is not supported " + "for gpt-oss yet. 
Please use Responses API instead.") + else: + # Normal message + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content=final_content, + ) + + choice_data = ChatCompletionResponseChoice( + index=output.index, + message=message, + logprobs=logprobs, + finish_reason="tool_calls" if is_tool_call else + output.finish_reason if output.finish_reason else "stop", + stop_reason=output.stop_reason, + ) + choices.append(choice_data) + continue + + enable_thinking: bool = request.chat_template_kwargs.get("enable_thinking", True) if request.chat_template_kwargs else True + if self.reasoning_parser and enable_thinking: + try: + reasoning_parser = self.reasoning_parser(tokenizer) + except RuntimeError as e: + logger.exception("Error in reasoning parser creation.") + return self.create_error_response(str(e)) + # If the reasoning parser is enabled, + # tool calls are extracted exclusively from the content. + reasoning_content, content = ( + reasoning_parser.extract_reasoning_content( + output.text, request=request)) + if not request.include_reasoning: + reasoning_content = None + else: + reasoning_content = None + content = output.text + + auto_tools_called = False + # if auto tools are not enabled, and a named tool choice using + # outlines is not being used + if (not self.enable_auto_tools or not self.tool_parser) and \ + (not isinstance(request.tool_choice, + ChatCompletionNamedToolChoiceParam + ) and request.tool_choice != "required"): + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=content) + + # if the request uses tools and specified a tool choice + elif request.tool_choice and type( + request.tool_choice) is ChatCompletionNamedToolChoiceParam: + + tool_call_class = MistralToolCall if isinstance( + tokenizer, MistralTokenizer) else ToolCall + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content="", + tool_calls=[ + tool_call_class(function=FunctionCall( + 
name=request.tool_choice.function.name, + arguments=content, + )) + ], + ) + + elif request.tool_choice and request.tool_choice == "required": + tool_call_class = MistralToolCall if isinstance( + tokenizer, MistralTokenizer) else ToolCall + + # the fields of FunctionDefinition are a superset of the + # tool call outputs and can be used for parsing + assert content is not None + tool_calls = TypeAdapter( + list[FunctionDefinition]).validate_json(content) + message = ChatMessage( + role=role, + content="", + reasoning_content=reasoning_content, + tool_calls=[ + tool_call_class(function=FunctionCall( + name=tool_call.name, + arguments=json.dumps(tool_call.parameters, + ensure_ascii=False))) + for tool_call in tool_calls + ]) + + # if the request doesn't use tool choice + # OR specifies to not use a tool + elif not request.tool_choice or request.tool_choice == "none": + + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=content) + + # handle when there are tools and tool choice is auto + elif request.tools and ( + request.tool_choice == "auto" + or request.tool_choice is None) and self.enable_auto_tools \ + and self.tool_parser: + + try: + tool_parser = self.tool_parser(tokenizer) + except RuntimeError as e: + logger.exception("Error in tool parser creation.") + return self.create_error_response(str(e)) + + tool_call_info = tool_parser.extract_tool_calls( + content if content is not None else "", request=request) + # In the OpenAI API the finish_reason is "tools_called" + # if the tool choice is auto and the model produced a tool + # call. The same is not true for named function calls + auto_tools_called = tool_call_info.tools_called + if tool_call_info.tools_called: + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=tool_call_info.content, + tool_calls=tool_call_info.tool_calls) + + else: + # FOR NOW make it a chat message; we will have to detect + # the type to make it later. 
+ ret_content = content + + # try to use content return from tool parser first, + # tool parser may do some modify for the content. + if (tool_call_info.content + and len(tool_call_info.content) > 0): + ret_content = tool_call_info.content + + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=ret_content) + + # undetermined case that is still important to handle + else: + logger.error( + "Error in chat_completion_full_generator - cannot determine" + " if tools should be extracted. Returning a standard chat " + "completion.") + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=content) + + choice_data = ChatCompletionResponseChoice( + index=output.index, + message=message, + logprobs=logprobs, + finish_reason="tool_calls" if auto_tools_called else + output.finish_reason if output.finish_reason else "stop", + stop_reason=output.stop_reason) + + choices.append(choice_data) + + if request.echo: + last_msg_content: Union[str, list[dict[str, str]]] = "" + if (conversation and "content" in conversation[-1] + and conversation[-1].get("role") == role): + last_msg_content = conversation[-1]["content"] or "" + if isinstance(last_msg_content, list): + last_msg_content = "\n".join(msg['text'] + for msg in last_msg_content) + + for choice in choices: + full_message = last_msg_content + (choice.message.content + or "") + choice.message.content = full_message + + assert final_res.prompt_token_ids is not None + num_prompt_tokens = len(final_res.prompt_token_ids) + if final_res.encoder_prompt_token_ids is not None: + num_prompt_tokens += len(final_res.encoder_prompt_token_ids) + num_generated_tokens = sum( + len(output.token_ids) for output in final_res.outputs) + usage = UsageInfo(prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + + num_generated_tokens) + if self.enable_prompt_tokens_details and final_res.num_cached_tokens: + usage.prompt_tokens_details = 
PromptTokenUsageInfo( + cached_tokens=final_res.num_cached_tokens) + + request_metadata.final_usage_info = usage + + response = ChatCompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs), + kv_transfer_params=final_res.kv_transfer_params, + ) + + # Log complete response if output logging is enabled + if self.enable_log_outputs and self.request_logger: + for choice in choices: + output_text = "" + if choice.message.content: + output_text = choice.message.content + elif choice.message.tool_calls: + # For tool calls, log the function name and arguments + tool_call_descriptions = [] + for tool_call in choice.message.tool_calls: + if hasattr(tool_call.function, "name") and hasattr( + tool_call.function, "arguments"): + tool_call_descriptions.append( + f"{tool_call.function.name}({tool_call.function.arguments})" + ) + tool_calls_str = ", ".join(tool_call_descriptions) + output_text = f"[tool_calls: {tool_calls_str}]" + + if output_text: + # Get the corresponding output token IDs + output_token_ids = None + if choice.index < len(final_res.outputs): + output_token_ids = final_res.outputs[ + choice.index].token_ids + + self.request_logger.log_outputs( + request_id=request_id, + outputs=output_text, + output_token_ids=output_token_ids, + finish_reason=choice.finish_reason, + is_streaming=False, + delta=False, + ) + + return response \ No newline at end of file diff --git a/vllm_kunlun/entrypoints/openai/tool_parsers/glm47_moe_tool_parser.py b/vllm_kunlun/entrypoints/openai/tool_parsers/glm47_moe_tool_parser.py new file mode 100644 index 0000000..1053a77 --- /dev/null +++ b/vllm_kunlun/entrypoints/openai/tool_parsers/glm47_moe_tool_parser.py @@ -0,0 +1,912 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import ast +from functools import partial +from importlib.resources import contents 
+import json +from collections.abc import Sequence +from typing import Any, Optional, Union + +import regex as re +from enum import Enum +from vllm.utils import random_uuid + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + +class StreamState(str, Enum): + """State machine states for XML to JSON streaming conversion.""" + + INIT = "INIT" + BETWEEN = "BETWEEN" + IN_KEY = "IN_KEY" + WAITING_VALUE = "WAITING_VALUE" + IN_VALUE = "IN_VALUE" + +def random_tool_call_id() -> str: + return f"chatcmpl-tool-{random_uuid()}" + +def get_argument_type( + func_name: str, arg_key: str, defined_tools: list[ChatCompletionToolsParam] +) -> Optional[str]: + """Get the expected type of a function argument from tool definitions. 
+ + Supports complex JSON Schema definitions including: + - Direct type field (including type arrays) + - anyOf/oneOf: parameter can be any of multiple types + - enum: parameter must be one of enum values + - allOf: parameter must satisfy all type definitions + - properties: inferred as object type + - items: inferred as array type + + Args: + func_name: Name of the function/tool + arg_key: Name of the argument + defined_tools: List of available tools + + Returns: + The type string (e.g., 'string', 'number', 'object') or None if not found + """ + name2tool = {tool.function.name: tool for tool in defined_tools} + + # Check if function exists + tool = name2tool.get(func_name) + if not tool: + return None + + # Get parameters safely using getattr + params = getattr(tool.function, "parameters", None) + if not isinstance(params, dict): + return None + + # Navigate to the type using dict.get() for safe access + properties = params.get("properties") + if not isinstance(properties, dict): + return None + + arg_spec = properties.get(arg_key) + if isinstance(arg_spec, dict): + # Use the new type inference function for complex JSON Schema support + return infer_type_from_json_schema(arg_spec) + + return None + +def _convert_to_number(value: str) -> Any: + """Convert string to appropriate number type (int or float). + + Args: + value: String value to convert + + Returns: + Converted number or original string if conversion fails + """ + try: + if "." in value or "e" in value.lower(): + return float(value) + else: + return int(value) + except (ValueError, AttributeError): + return value + + +def parse_arguments( + json_value: str, arg_type: Optional[str] = None +) -> tuple[Any, bool]: + """Parse argument value with multiple fallback strategies. + + Args: + json_value: Raw string value to parse + arg_type: Expected type hint ('string', 'number', 'object', etc.) 
+ + Returns: + Tuple of (parsed_value, is_valid_json) + """ + # Strategy 1: Direct JSON parsing + try: + parsed_value = json.loads(json_value) + + # Type coercion for number type + if arg_type == "number" and isinstance(parsed_value, str): + parsed_value = _convert_to_number(parsed_value) + + return parsed_value, True + except (json.JSONDecodeError, ValueError): + pass + + # Strategy 2: Unescape and parse + try: + wrapped = json.loads('{"tmp": "' + json_value + '"}') + parsed_value = json.loads(wrapped["tmp"]) + + if arg_type == "number" and isinstance(parsed_value, str): + parsed_value = _convert_to_number(parsed_value) + + return parsed_value, True + except (json.JSONDecodeError, ValueError, KeyError): + pass + + # Strategy 3: ast.literal_eval + try: + parsed_value = ast.literal_eval(json_value) + return parsed_value, True + except (ValueError, SyntaxError): + pass + + # Strategy 4: Treat as string + try: + quoted_value = json.dumps(str(json_value)) + return json.loads(quoted_value), True + except (json.JSONDecodeError, ValueError): + return json_value, False + +def infer_type_from_json_schema(schema: dict[str, Any]) -> Optional[str]: + """ + Infer the primary type of a parameter from JSON Schema. + + Supports complex JSON Schema structures including: + - Direct type field (including type arrays) + - anyOf/oneOf: parameter can be any of multiple types + - enum: parameter must be one of enum values + - allOf: parameter must satisfy all type definitions + - properties: inferred as object type + - items: inferred as array type + + Args: + schema: JSON Schema definition + + Returns: + Inferred type ('string', 'number', 'object', 'array', etc.) 
or None + """ + if not isinstance(schema, dict): + return None + + # Priority 1: Direct type field (including type arrays) + if "type" in schema: + type_value = schema["type"] + if isinstance(type_value, str): + return type_value + elif isinstance(type_value, list) and type_value: + # Handle type arrays: return first non-null type + non_null_types = [t for t in type_value if t != "null"] + if non_null_types: + return non_null_types[0] + return "string" # If only null, default to string + + # Priority 2: Handle anyOf/oneOf + if "anyOf" in schema or "oneOf" in schema: + schemas = schema.get("anyOf") or schema.get("oneOf") + types = [] + + if isinstance(schemas, list): + for sub_schema in schemas: + inferred_type = infer_type_from_json_schema(sub_schema) + if inferred_type: + types.append(inferred_type) + + if types: + # If all types are the same, return unified type + if len(set(types)) == 1: + return types[0] + # When types differ, prioritize string (safest) + if "string" in types: + return "string" + # Otherwise return first type + return types[0] + + # Priority 3: Handle enum (infer type from enum values) + if "enum" in schema and isinstance(schema["enum"], list): + if not schema["enum"]: + return "string" + + # Infer type from enum values + enum_types = set() + for value in schema["enum"]: + if value is None: + enum_types.add("null") + elif isinstance(value, bool): + enum_types.add("boolean") + elif isinstance(value, int): + enum_types.add("integer") + elif isinstance(value, float): + enum_types.add("number") + elif isinstance(value, str): + enum_types.add("string") + elif isinstance(value, list): + enum_types.add("array") + elif isinstance(value, dict): + enum_types.add("object") + + # If type is uniform, return that type + if len(enum_types) == 1: + return enum_types.pop() + # Mixed types, prioritize string + return "string" + + # Priority 4: Handle allOf (must satisfy all types) + if "allOf" in schema and isinstance(schema["allOf"], list): + schemas = 
schema["allOf"] + for sub_schema in schemas: + inferred_type = infer_type_from_json_schema(sub_schema) + if inferred_type and inferred_type != "string": + return inferred_type + return "string" + + # Priority 5: Infer object type + if "properties" in schema: + return "object" + + # Priority 6: Infer array type + if "items" in schema: + return "array" + + return None + +@ToolParserManager.register_module("glm47") +class Glm47MoeModelToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + self.current_tool_name_sent = False + self.prev_tool_call_arr: list[dict] = [] + self.current_tool_id = -1 + self.streamed_args_for_tool: list[str] = [] + self.tool_call_start_token = "" + self.tool_call_end_token = "" + self._tool_indices = 0 + self._last_arguments: str = "" + self._streamed_raw_length = 0 + + self.tool_calls_start_token = self.tool_call_start_token + + self.func_call_regex = re.compile(r".*?", + re.DOTALL) + self.func_detail_regex = re.compile( + r"([^\n<]*)\n?(.*)", re.DOTALL) + self.func_arg_regex = re.compile( + r"(.*?)\s*(.*?)", + re.DOTALL) + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + + self.tool_call_start_token_id = self.vocab.get( + self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + self._buffer = "" + self._reset_streaming_state() + + def _reset_streaming_state(self) -> None: + """Reset the streaming state machine for a new tool call.""" + self._stream_state = StreamState.INIT + self._current_key = "" + self._current_value = "" + self._xml_tag_buffer = "" + self._is_first_param = True + self._value_started = False + self._cached_value_type: Optional[str] = ( + None # Cache the value type for consistency + ) + self._tool_call_completed = False # Reset tool call completion status + self._sent_empty_object = False # Reset empty object sent status + + def 
extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + + def _is_string_type( + tool_name: str, arg_name: str, + tools: Optional[list[ChatCompletionToolsParam]]) -> bool: + if tools is None: + return False + for tool in tools: + if tool.function.name == tool_name: + if tool.function.parameters is None: + return False + arg_type = tool.function.parameters.get( + "properties", {}).get(arg_name, {}).get("type", None) + return arg_type == "string" + logger.warning("No tool named '%s'.", tool_name) + return False + + def _deserialize(value: str) -> Any: + try: + return json.loads(value) + except Exception: + pass + + try: + return ast.literal_eval(value) + except Exception: + pass + return value + + matched_tool_calls = self.func_call_regex.findall(model_output) + logger.debug("model_output: %s", model_output) + try: + tool_calls = [] + for match in matched_tool_calls: + tc_detail = self.func_detail_regex.search(match) + tc_name = tc_detail.group(1) + tc_args = tc_detail.group(2) + pairs = self.func_arg_regex.findall(tc_args) + arg_dct = {} + for key, value in pairs: + arg_key = key.strip() + arg_val = value.strip() + if not _is_string_type(tc_name, arg_key, request.tools): + arg_val = _deserialize(arg_val) + logger.debug("arg_key = %s, arg_val = %s", arg_key, + arg_val) + arg_dct[arg_key] = arg_val + tool_calls.append( + ToolCall(type="function", + function=FunctionCall( + name=tc_name, arguments=json.dumps(arg_dct)))) + except Exception: + logger.exception("Failed to extract tool call spec") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + else: + if len(tool_calls) > 0: + content = model_output[:model_output. 
+ find(self.tool_calls_start_token)] + return ExtractedToolCallInformation(tools_called=True, + tool_calls=tool_calls, + content=content) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def _extract_match_groups(self, match: re.Match) -> tuple[str, str, str]: + """Extract function name, arguments and end marker from regex match. + + Args: + match: Regex match object + + Returns: + (func_name, func_args_raw, is_tool_end) + """ + func_name = match.group(1).strip() + func_args_raw = match.group(2).strip() if match.group(2) else "" + is_tool_end = match.group(3) or "" + return func_name, func_args_raw, is_tool_end + + def _send_tool_name_if_needed( + self, func_name: str, has_arg_key: bool, is_tool_end: str + ) -> Optional[DeltaToolCall]: + """Send tool name if needed. + + Args: + func_name: Function name + has_arg_key: Whether current text contains dict[str, Any]: + """Parse argument key-value pairs with type coercion. + + Args: + pairs: List of (key, value) tuples from regex matching + func_name: Name of the function + tools: List of available tools + + Returns: + Dictionary of parsed arguments + """ + arguments = {} + for arg_key, arg_value in pairs: + arg_key = arg_key.strip() + arg_value = arg_value.strip() + arg_type = get_argument_type(func_name, arg_key, tools) + parsed_value, is_good_json = parse_arguments(arg_value, arg_type) + + if arg_type == "string": + # Only convert to string if explicitly defined as string type + if isinstance(parsed_value, str): + arguments[arg_key] = parsed_value + elif isinstance(parsed_value, (dict, list)): + # If parsed as dict/list but schema says string, convert to JSON string + arguments[arg_key] = json.dumps(parsed_value, ensure_ascii=False) + else: + arguments[arg_key] = str(parsed_value) + elif arg_type is None: + # If type is not defined, keep the parsed value as-is + arguments[arg_key] = parsed_value if is_good_json else arg_value + else: + # For other types (number, 
object, array, etc.), use parsed value + arguments[arg_key] = parsed_value if is_good_json else arg_value + + return arguments + + def _finalize_tool_call( + self, + func_name: str, + func_args_raw: str, + tools: list[ChatCompletionToolsParam], + match_end_pos: int, + current_text: str, + ) -> list[DeltaToolCall]: + """Complete tool call processing. + + Args: + func_name: Function name + func_args_raw: Raw argument string + tools: List of available tools + match_end_pos: Match end position + current_text: Current text + + Returns: + List of tool call items to add + """ + calls = [] + + # Handle no-arg function or need to close braces + if self._is_first_param and not self._sent_empty_object: + # No-arg function + calls.append( + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall(name=None, arguments="{}"), + ) + ) + self._last_arguments += "{}" + self.streamed_args_for_tool[self.current_tool_id] += "{}" + self._sent_empty_object = True + elif not self._last_arguments.endswith("}") and not self._sent_empty_object: + # Need to close brace + calls.append( + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall(name=None, arguments="}"), + ) + ) + self._last_arguments += "}" + self.streamed_args_for_tool[self.current_tool_id] += "}" + self._sent_empty_object = True + + # Parse final arguments + if func_args_raw: + try: + pairs = self.func_arg_regex.findall(func_args_raw) + if pairs: + arguments = self._parse_argument_pairs(pairs, func_name, tools) + self.prev_tool_call_arr[self.current_tool_id][ + "arguments" + ] = arguments + except Exception as e: + logger.debug(f"Failed to parse arguments: {e}", exc_info=True) + + # Clean buffer + self._buffer = current_text[match_end_pos:] + + # Reset state for next tool call + self._tool_call_completed = True + self.current_tool_id += 1 + self._last_arguments = "" + self.current_tool_name_sent = False + self._streamed_raw_length = 0 + self._reset_streaming_state() + + return calls + + def 
_format_value_complete(self, value: str, value_type: str) -> str: + """Format complete value based on type. + + Args: + value: Raw value string + value_type: Expected type ('string', 'number', 'object') + + Returns: + Properly formatted JSON value string + """ + if value_type == "string": + # Ensure proper JSON string formatting with quotes + return json.dumps(value, ensure_ascii=False) + elif value_type == "number": + try: + num = _convert_to_number(value.strip() if value else "") + return str(num) + except (ValueError, AttributeError): + # Fallback to string if not a valid number + logger.warning( + f"Failed to parse '{value}' as number, treating as string" + ) + return json.dumps(str(value) if value else "", ensure_ascii=False) + else: + # For object/array types, return as-is (should already be valid JSON) + return value + + + def _process_xml_to_json_streaming( + self, raw_increment: str, func_name: str, tools: list[ChatCompletionToolsParam] + ) -> str: + """Convert XML increment to JSON streaming output using state machine. + + This method processes XML fragments character by character and converts them + to JSON format incrementally. It maintains state across calls to handle + partial XML tags and values. 
+ + Args: + raw_increment: New XML content to process + func_name: Name of the function being called + tools: List of available tools for type inference + + Returns: + JSON string increment to append to the output + """ + json_output = "" + + for char in raw_increment: + self._xml_tag_buffer += char + + if self._stream_state in [StreamState.INIT, StreamState.BETWEEN]: + if self._xml_tag_buffer.endswith(""): + self._stream_state = StreamState.IN_KEY + self._current_key = "" + self._xml_tag_buffer = "" + json_output += "{" if self._is_first_param else ", " + self._is_first_param = False + + elif self._stream_state == StreamState.IN_KEY: + if self._xml_tag_buffer.endswith(""): + self._current_key = self._xml_tag_buffer[:-10].strip() + self._xml_tag_buffer = "" + self._stream_state = StreamState.WAITING_VALUE + json_output += ( + json.dumps(self._current_key, ensure_ascii=False) + ": " + ) + + elif self._stream_state == StreamState.WAITING_VALUE: + if self._xml_tag_buffer.endswith(""): + self._stream_state = StreamState.IN_VALUE + self._current_value = "" + self._xml_tag_buffer = "" + self._value_started = False + # Determine and cache the value type at the start + self._cached_value_type = self._get_value_type( + func_name, self._current_key, tools + ) + + elif self._stream_state == StreamState.IN_VALUE: + if self._xml_tag_buffer.endswith(""): + final_value = self._xml_tag_buffer[:-12] + self._current_value += final_value + + # Use cached value type for consistency + value_type = self._cached_value_type or "string" + + if self._value_started: + # Output any remaining content + if final_value: + if value_type == "string": + json_output += json.dumps( + final_value, ensure_ascii=False + )[1:-1] + else: + json_output += final_value + # Always output closing quote for string type when value was started + if value_type == "string": + json_output += '"' + else: + # Value was never started (empty or complete in one chunk) + json_output += self._format_value_complete( + 
self._current_value, value_type + ) + + self._xml_tag_buffer = "" + self._stream_state = StreamState.BETWEEN + self._current_value = "" + self._value_started = False + self._cached_value_type = None # Reset cached type + else: + closing_tag = "" + is_potential_closing = len(self._xml_tag_buffer) <= len( + closing_tag + ) and closing_tag.startswith(self._xml_tag_buffer) + + if not is_potential_closing: + content = self._xml_tag_buffer + # Use cached value type for consistency + value_type = self._cached_value_type or "string" + + if value_type == "string": + if not self._value_started: + json_output += '"' + self._value_started = True + if content: + json_output += json.dumps(content, ensure_ascii=False)[ + 1:-1 + ] + self._current_value += content + self._xml_tag_buffer = "" + elif value_type == "number": + if content: + if not self._value_started: + self._value_started = True + json_output += content + self._current_value += content + self._xml_tag_buffer = "" + else: + # For object/array types, output as-is + if content: + if not self._value_started: + self._value_started = True + json_output += content + self._current_value += content + self._xml_tag_buffer = "" + + return json_output + + def _get_value_type(self, func_name: str, key: str, tools: list[ChatCompletionToolsParam]) -> str: + """Get parameter type from tool definition, with fallback to auto-detection. 
+ + Args: + func_name: Name of the function + key: Parameter name + tools: List of available tools + + Returns: + Type string: 'string', 'number', 'object', 'array', or 'boolean' + """ + arg_type = get_argument_type(func_name, key, tools) + if arg_type: + return arg_type + + # Improved auto-detection type from value (best effort) + value_content = self._current_value.strip() if self._current_value else "" + + if not value_content: + return "string" + + # Try to parse as valid JSON first + try: + parsed = json.loads(value_content) + if isinstance(parsed, dict): + return "object" + elif isinstance(parsed, list): + return "array" + elif isinstance(parsed, bool): + return "boolean" + elif isinstance(parsed, (int, float)): + return "number" + # For string values, check if they look like numbers + elif isinstance(parsed, str): + if parsed.isdigit() or ( + parsed.startswith("-") and parsed[1:].isdigit() + ): + return "number" + return "string" + except json.JSONDecodeError: + # Not valid JSON, try heuristic detection + first_char = value_content[0] if value_content else "" + + if first_char.isdigit() or first_char in ["-", "."]: + return "number" + elif first_char in ["{", "["]: + return "object" + elif first_char in ['"', "'"]: + return "string" + + # Default to string (safest fallback) + return "string" + + + def _process_arguments_streaming( + self, func_name: str, func_args_raw: str, tools: list[ChatCompletionToolsParam] + ) -> Optional[DeltaToolCall]: + """Process streaming arguments. 
+ + Args: + func_name: Function name + func_args_raw: Raw argument string + tools: List of available tools + + Returns: + Tool call item with parameter updates or None + """ + current_raw_length = len(func_args_raw) + + if current_raw_length <= self._streamed_raw_length: + return None + + # Get new raw XML content + raw_increment = func_args_raw[self._streamed_raw_length :] + + # Convert XML to JSON using state machine + json_increment = self._process_xml_to_json_streaming( + raw_increment, func_name, tools + ) + + # CRITICAL: Update streamed length BEFORE early return + # Even if json_increment is empty, the input has been consumed by the state machine + self._streamed_raw_length = current_raw_length + + if not json_increment: + return None + + # Update state + self._last_arguments += json_increment + self.streamed_args_for_tool[self.current_tool_id] += json_increment + + return DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall(name=None, arguments=json_increment), + ) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + self._buffer += delta_text + current_text = self._buffer + # Check if we have a tool call + has_tool_call = self.tool_call_start_token in current_text + + if not has_tool_call: + # Check if buffer could be the start of a tool call + # Keep buffer if it could be a partial match of bot_token + is_potential_start = any( + self.tool_call_start_token.startswith(current_text[-i:]) + for i in range(1, min(len(current_text), len(self.tool_call_start_token)) + 1) + ) + + if not is_potential_start: + # Not a potential tool call, return as normal text + # Must return the entire buffer (current_text), not just new_text, + # because buffer may contain previously accumulated characters like '<' + # that turned 
out not to be part of a tool call + output_text = current_text + self._buffer = "" + if self.tool_call_end_token in output_text: + output_text = output_text.replace(self.tool_call_end_token, "") + return DeltaMessage(content=output_text) + else: + # Could be start of tool call, keep buffering + return None + + # Extract any text before the first bot_token and return it as normal_text + output_text = "" + first_bot_token_idx = current_text.find(self.tool_call_start_token) + if first_bot_token_idx > 0: + output_text= current_text[:first_bot_token_idx] + current_text = current_text[first_bot_token_idx:] + # Update buffer to only include from the bot token onwards + self._buffer = current_text + if not hasattr(self, "_tool_indices"): + self._tool_indices += 1 + + calls: list[DeltaToolCall] = [] + try: + # Try to match a partial or complete tool call + # Use a single flexible regex pattern that handles all cases + partial_match = re.search( + r"(.*?)(?:()|$)", + current_text, + re.DOTALL, + ) + if not partial_match: + return None + # return DeltaMessage(content=output_text, tool_calls=[]) + + # Extract match groups using helper method + func_name, func_args_raw, is_tool_end = self._extract_match_groups( + match=partial_match + ) + + # Initialize tool call state if needed (keeping existing logic) + if self.current_tool_id == -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [""] + self._streamed_raw_length = 0 + self.current_tool_name_sent = False # Reset for new tool call + self._reset_streaming_state() + # Check if this is a continuation of an existing tool call or a new one + elif not self.current_tool_name_sent: + # Only increment tool_id if we're truly starting a NEW tool call + # Don't increment if this is just the first time we're processing + # a tool call that was received in the buffer + # The key insight: only increment when we've COMPLETED a previous tool call + # and now see another bot_token in new_text + pass # 
Remove the problematic auto-increment logic + + # Ensure tracking arrays are large enough (keeping existing logic) + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + # Determine if function name is complete by checking for in the full text + # This is important for streaming scenarios where args come in later chunks + has_arg_key = "