# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501, SIM102

import json
from collections.abc import Sequence
from typing import Any, Optional, Union

import regex as re

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaFunctionCall, DeltaMessage,
                                              DeltaToolCall,
                                              ExtractedToolCallInformation,
                                              FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import consume_space
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid

logger = init_logger(__name__)


@ToolParserManager.register_module("hunyuan_a13b")
class HunyuanA13BToolParser(ToolParser):
    """Tool-call parser for Hunyuan-A13B outputs.

    Tool calls are read from a JSON array wrapped in
    ``<tool_calls>...</tool_calls>``; candidate blocks lying inside a
    ``<think>...</think>`` region are ignored.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        # Initialize state for streaming mode
        self.prev_tool_calls: list[dict] = []
        self.current_tool_id = -1
        self.current_tool_name_sent = False
        # Arguments text already streamed, one entry per tool index.
        self.streamed_args: list[str] = []

        # For backward compatibility with tests
        self.current_tools_sent: list[bool] = []

        # For backward compatibility with serving code
        self.prev_tool_call_arr = []

        # NOTE(review): tag names assumed from the Hunyuan-A13B chat
        # template (<tool_calls> for calls, <think> for reasoning) — confirm.
        self.answer_tool_calls_pattern = re.compile(
            r"<tool_calls>([\s\S]*?)</tool_calls>", re.DOTALL)
        self.think_region_pattern = re.compile(r"<think>(.*?)</think>",
                                               re.DOTALL)

        # Captures the value of every "name": "..." pair in the text.
        self.tool_name_reg = re.compile(r'"name"\s*:\s*"([^"]+)"')

        # A tool whose arguments object is empty ({} with optional spaces).
        self.tool_empty_arg_reg = re.compile(
            r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}')

        # Captures a complete arguments object (including an empty one).
        # TODO: objects nested more than one level deep inside the
        # arguments are not supported by this pattern.
        self.tool_non_empty_arg_reg = re.compile(
            r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
        )

        # Optional prefix that precedes the JSON array while streaming.
        self.bot_string = "<tool_calls>"

        # Streaming progress: index of the tool currently being streamed,
        # the generated id of each tool, and a per-tool dict recording what
        # has already been sent.
        self.streaming_state: dict[str, Any] = {
            "current_tool_index": -1,
            "tool_ids": [],
            "sent_tools": [],
        }

    def preprocess_model_output(
            self, model_output: str) -> tuple[Optional[str], Optional[str]]:
        """Locate the first usable tool-call block in ``model_output``.

        A candidate block is skipped when it lies strictly inside a
        ``<think>...</think>`` region or when its contents are not valid
        JSON.

        Returns:
            ``(text before the block, stripped JSON text inside the block)``
            for the first usable block, else ``(model_output, None)``.
        """
        # Think regions do not depend on the candidate, so scan them once.
        think_regions = [
            m.span()
            for m in self.think_region_pattern.finditer(model_output)
        ]
        for match in self.answer_tool_calls_pattern.finditer(model_output):
            start, end = match.span()
            in_think = any(start > t_start and end < t_end
                           for t_start, t_end in think_regions)
            if in_think:
                continue
            tool_calls_content = match.group(1).strip()
            try:
                json.loads(tool_calls_content)
            except json.JSONDecodeError:
                continue
            return model_output[:start], tool_calls_content
        return model_output, None

    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """
        Extract tool calls from a complete model output.

        Entries of the JSON array that are not objects with both ``name``
        and ``arguments`` are skipped. Non-string ``arguments`` are
        serialized with ``json.dumps``. Any unexpected error is logged and
        the raw output is returned as content.
        """
        try:
            content, potential_tool_calls = self.preprocess_model_output(
                model_output)

            if not potential_tool_calls:
                # Without a function call, drop the first "助手:" ("Assistant:")
                # prefix, which comes from the A13B chat template.
                if content:
                    content = content.replace("助手:", "", 1)
                return ExtractedToolCallInformation(tools_called=False,
                                                    tool_calls=[],
                                                    content=content)

            tool_calls_data = json.loads(potential_tool_calls)

            if not isinstance(tool_calls_data, list):
                logger.debug("Tool calls data is not an array")
                return ExtractedToolCallInformation(
                    tools_called=False,
                    tool_calls=[],
                    content=content or model_output,
                )

            tool_calls: list[ToolCall] = []
            for call in tool_calls_data:
                if (not isinstance(call, dict) or "name" not in call
                        or "arguments" not in call):
                    continue
                arguments = call["arguments"]
                tool_calls.append(
                    ToolCall(
                        id=f"call_{random_uuid()}",
                        type="function",
                        function=FunctionCall(
                            name=call["name"],
                            # FunctionCall.arguments is a string: keep
                            # strings as-is and serialize anything else
                            # (objects, lists, null, numbers).
                            arguments=(arguments if isinstance(
                                arguments, str) else json.dumps(arguments)),
                        ),
                    ))

            if not content or not content.strip():
                # Whitespace-only content is reported as no content.
                content = None

            return ExtractedToolCallInformation(
                tools_called=len(tool_calls) > 0,
                tool_calls=tool_calls,
                content=content,
            )

        except Exception:
            logger.exception("Error in extracting tool call from response.")
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Extract tool calls for streaming mode.

        If ``current_text`` (after leading whitespace and an optional
        ``self.bot_string``) does not start with ``[``, the delta is passed
        through as plain content. Otherwise at most one tool-call delta
        (a name, or an arguments fragment) is returned per call, or None
        when there is nothing new to emit.
        """
        start_idx = consume_space(0, current_text)
        if current_text[start_idx:].startswith(self.bot_string):
            start_idx = consume_space(start_idx + len(self.bot_string),
                                      current_text)
        if not current_text or start_idx >= len(
                current_text) or current_text[start_idx] != '[':
            return DeltaMessage(content=delta_text)

        self._try_parse_json_tools(current_text[start_idx:])

        test_delta = self._handle_test_compatibility(current_text)
        if test_delta:
            return test_delta

        # NOTE(review): tool_name_reg also matches "name" keys inside an
        # arguments object, which would over-count tools — confirm whether
        # such arguments occur in practice.
        name_matches = list(self.tool_name_reg.finditer(current_text))
        tool_count = len(name_matches)
        if tool_count == 0:
            return None
        self._ensure_state_arrays(tool_count)
        current_idx = self.streaming_state["current_tool_index"]

        name_delta = self._handle_tool_name_streaming(current_idx, tool_count,
                                                      name_matches)
        if name_delta:
            return name_delta

        args_delta = self._handle_tool_args_streaming(current_text,
                                                      current_idx, tool_count)
        if args_delta:
            return args_delta

        return None

    def _try_parse_json_tools(self, current_text: str):
        """Update ``prev_tool_call_arr`` once the leading JSON array is complete.

        ``raw_decode`` ignores anything after the array (e.g. a closing tag),
        so the array is still captured once more text has streamed in.
        """
        try:
            parsed_tools, _ = json.JSONDecoder().raw_decode(current_text)
        except json.JSONDecodeError:
            return
        if isinstance(parsed_tools, list):
            self.prev_tool_call_arr = parsed_tools

    def _handle_test_compatibility(self, current_text: str):
        """Legacy path: emit the first tool's name when ``current_tools_sent``
        is exactly ``[False]``, and mark tool 0's name as sent."""
        if len(self.current_tools_sent) > 0:
            if (len(self.current_tools_sent) == 1
                    and self.current_tools_sent[0] is False):
                name_match = self.tool_name_reg.search(current_text)
                if name_match:
                    function_name = name_match.group(1)
                    tool_id = f"chatcmpl-tool-{random_uuid()}"
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(
                            index=0,
                            type="function",
                            id=tool_id,
                            function=DeltaFunctionCall(
                                name=function_name).model_dump(
                                    exclude_none=True),
                        )
                    ])
                    self.current_tools_sent = [True]
                    self.current_tool_id = 0
                    self.streaming_state["current_tool_index"] = 0
                    if len(self.streaming_state["sent_tools"]) == 0:
                        self.streaming_state["sent_tools"].append({
                            "sent_name": True,
                            "sent_arguments_prefix": False,
                            "sent_arguments": "",
                        })
                    else:
                        self.streaming_state["sent_tools"][0][
                            "sent_name"] = True
                    self.current_tool_name_sent = True
                    return delta
        return None

    def _ensure_state_arrays(self, tool_count: int):
        """Grow the per-tool state lists to hold ``tool_count`` tools."""
        while len(self.streaming_state["sent_tools"]) < tool_count:
            self.streaming_state["sent_tools"].append({
                "sent_name": False,
                "sent_arguments_prefix": False,
                "sent_arguments": "",
            })
        while len(self.streaming_state["tool_ids"]) < tool_count:
            self.streaming_state["tool_ids"].append(None)

    def _handle_tool_name_streaming(self, current_idx: int, tool_count: int,
                                    name_matches):
        """Emit the name of the next tool if it has not been sent yet.

        On success this also makes that tool the current one. NOTE(review):
        the switch happens as soon as the next name is visible, even if the
        previous tool's arguments have not all been streamed — confirm that
        no argument text is lost in that case.
        """
        if current_idx == -1 or current_idx < tool_count - 1:
            next_idx = current_idx + 1
            if (next_idx < tool_count
                    and not self.streaming_state["sent_tools"][next_idx]
                ["sent_name"]):
                self.streaming_state["current_tool_index"] = next_idx
                self.current_tool_id = next_idx
                current_idx = next_idx
                tool_name = name_matches[current_idx].group(1)
                tool_id = f"call_{current_idx}_{random_uuid()}"
                self.streaming_state["tool_ids"][current_idx] = tool_id
                delta = DeltaMessage(tool_calls=[
                    DeltaToolCall(
                        index=current_idx,
                        type="function",
                        id=tool_id,
                        function=DeltaFunctionCall(name=tool_name).model_dump(
                            exclude_none=True),
                    )
                ])
                self.streaming_state["sent_tools"][current_idx][
                    "sent_name"] = True
                self.current_tool_name_sent = True
                while len(self.streamed_args) <= current_idx:
                    self.streamed_args.append("")
                return delta
        return None

    def _handle_tool_args_streaming(self, current_text: str, current_idx: int,
                                    tool_count: int):
        """Emit the next arguments fragment for the tool at ``current_idx``.

        Arguments are emitted only once ``tool_non_empty_arg_reg`` matches
        the tool's complete arguments object: an empty object is sent as
        ``"{}"`` in one delta; otherwise ``"{"`` is sent first and the
        remainder on the next call. When everything has been sent and a
        further tool exists, streaming moves on to that tool.
        """
        if not 0 <= current_idx < tool_count:
            return None

        args_matches = list(
            self.tool_non_empty_arg_reg.finditer(current_text))
        if current_idx >= len(args_matches):
            return None
        args_match = args_matches[current_idx]
        # group(1) is exactly this tool's arguments object.
        args_text = args_match.group(1)
        tool_state = self.streaming_state["sent_tools"][current_idx]

        # Anchor the empty-arguments check at this tool's own match so that
        # another tool's "{}" is not mistaken for the current tool's.
        if (not tool_state["sent_arguments_prefix"]
                and self.tool_empty_arg_reg.match(current_text,
                                                  args_match.start())):
            tool_state["sent_arguments_prefix"] = True
            tool_state["sent_arguments"] = "{}"
            delta = self._make_args_delta(current_idx, "{}")
            if current_idx < tool_count - 1:
                self._advance_tool_index()
            return delta

        sent_args = tool_state["sent_arguments"]
        if not tool_state["sent_arguments_prefix"] and args_text.startswith(
                "{"):
            tool_state["sent_arguments_prefix"] = True
            tool_state["sent_arguments"] = "{"
            return self._make_args_delta(current_idx, "{")

        if args_text.startswith(sent_args):
            args_diff = args_text[len(sent_args):]
            if args_diff:
                tool_state["sent_arguments"] = args_text
                return self._make_args_delta(current_idx, args_diff)

        if args_text.endswith("}") and args_text == sent_args:
            if current_idx < tool_count - 1:
                self._advance_tool_index()
        return None

    def _make_args_delta(self, idx: int, text: str) -> DeltaMessage:
        """Record ``text`` as streamed for tool ``idx`` and wrap it in a delta."""
        while len(self.streamed_args) <= idx:
            self.streamed_args.append("")
        self.streamed_args[idx] += text
        return DeltaMessage(tool_calls=[
            DeltaToolCall(
                index=idx,
                function=DeltaFunctionCall(
                    arguments=text).model_dump(exclude_none=True),
            )
        ])

    def _advance_tool_index(self):
        """Move to the next tool, keeping ``current_tool_id`` in sync."""
        self.streaming_state["current_tool_index"] += 1
        self.current_tool_id = self.streaming_state["current_tool_index"]