Sync from v0.13
This commit is contained in:
420
vllm/tool_parsers/hunyuan_a13b_tool_parser.py
Normal file
420
vllm/tool_parsers/hunyuan_a13b_tool_parser.py
Normal file
@@ -0,0 +1,420 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: E501, SIM102
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.tool_parsers.utils import consume_space
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class HunyuanA13BToolParser(ToolParser):
    """Tool-call parser for Hunyuan A13B style model output.

    Tool calls are expected as a JSON array of
    ``{"name": ..., "arguments": {...}}`` objects. In non-streaming mode the
    array must be wrapped in ``<tool_calls>...</tool_calls>``; in streaming
    mode the text must start (after optional whitespace and an optional
    ``<tool_calls>`` tag) with ``[``.

    Streaming works by re-scanning the accumulated text with regular
    expressions on every delta and emitting only what has not been sent yet.
    Progress is tracked in ``self.streaming_state``.
    """

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)

        # Initialize state for streaming mode
        self.prev_tool_calls: list[dict] = []
        self.current_tool_id = -1
        self.current_tool_name_sent = False
        self.streamed_args: list[str] = []  # Track arguments sent for each tool

        # For backward compatibility with tests
        self.current_tools_sent: list[bool] = []

        # For backward compatibility with serving code
        self.prev_tool_call_arr = []

        # Regex patterns for preprocessing
        self.answer_tool_calls_pattern = re.compile(
            r"<tool_calls>([\s\S]*?)</tool_calls>", re.DOTALL
        )

        # Matches `"name": "<tool name>"`; group 1 is the tool name.
        self.tool_name_reg = re.compile(r'"name"\s*:\s*"([^"]+)"')

        # Matches a tool entry whose arguments object is empty (`{}`).
        self.tool_empty_arg_reg = re.compile(
            r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}'
        )

        # Matches a tool entry with a *closed* arguments object; group 1 is
        # the arguments text. Only one level of nested `{...}` is accepted.
        # TODO: not support nested json object in fc arguments.
        self.tool_non_empty_arg_reg = re.compile(
            r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
        )

        self.bot_string = "<tool_calls>"

        # Per-request streaming progress:
        #   current_tool_index: index of the tool currently being streamed
        #   tool_ids:           generated call id per tool (None until sent)
        #   sent_tools:         per-tool dict with "sent_name",
        #                       "sent_arguments_prefix" and "sent_arguments"
        self.streaming_state: dict[str, Any] = {
            "current_tool_index": -1,
            "tool_ids": [],
            "sent_tools": [],
        }

    def preprocess_model_output(
        self, model_output: str
    ) -> tuple[str | None, str | None]:
        """Locate the first usable ``<tool_calls>`` block in ``model_output``.

        A block is usable when it is not strictly inside a
        ``<think>...</think>`` region and its body parses as JSON.

        Returns:
            ``(content, tool_calls_json)`` where ``content`` is the text
            preceding the block and ``tool_calls_json`` is the stripped block
            body, or ``(model_output, None)`` when no usable block exists.
        """
        # The <think> regions do not depend on the individual match, so they
        # are computed at most once, and only if a <tool_calls> block exists.
        think_regions: list[tuple[int, int]] | None = None

        for match in self.answer_tool_calls_pattern.finditer(model_output):
            start, end = match.span()

            if think_regions is None:
                think_regions = [
                    (m.start(), m.end())
                    for m in re.finditer(
                        r"<think>(.*?)</think>", model_output, flags=re.DOTALL
                    )
                ]

            # Skip tool-call blocks that live inside a <think> region.
            in_think = any(
                start > t_start and end < t_end for t_start, t_end in think_regions
            )
            if in_think:
                continue

            tool_calls_content = match.group(1).strip()
            try:
                json.loads(tool_calls_content)
            except (ValueError, RecursionError):
                # Not valid JSON (json.JSONDecodeError is a ValueError) or
                # nested too deeply to parse: try the next block.
                continue
            return model_output[:start], tool_calls_content

        return model_output, None

    def extract_tool_calls(
        self, model_output: str, request: ChatCompletionRequest
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from a complete model output.

        Returns an ``ExtractedToolCallInformation`` with ``tools_called`` set
        when at least one well-formed ``{"name", "arguments"}`` entry was
        found. On any unexpected error the whole ``model_output`` is returned
        as plain content with no tool calls.
        """
        try:
            # Preprocess the model output
            content, potential_tool_calls = self.preprocess_model_output(model_output)

            if not potential_tool_calls:
                # some text should be filtered out for no function call
                # this text is in a13b's chat template.
                if content:
                    content = content.replace("助手:", "", 1)
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=content
                )

            # Parse the potential tool calls as JSON
            tool_calls_data = json.loads(potential_tool_calls)

            # Ensure it's an array
            if not isinstance(tool_calls_data, list):
                logger.debug("Tool calls data is not an array")
                return ExtractedToolCallInformation(
                    tools_called=False,
                    tool_calls=[],
                    content=content or model_output,
                )

            tool_calls: list[ToolCall] = []

            for call in tool_calls_data:
                # Silently skip entries that are not {"name", "arguments"}.
                if (
                    not isinstance(call, dict)
                    or "name" not in call
                    or "arguments" not in call
                ):
                    continue

                # Strings are forwarded untouched; every other JSON value
                # (object, array, number, null, ...) is serialized so that
                # the arguments field always carries JSON text.
                raw_arguments = call["arguments"]
                arguments = (
                    raw_arguments
                    if isinstance(raw_arguments, str)
                    else json.dumps(raw_arguments)
                )

                tool_calls.append(
                    ToolCall(
                        id=f"call_{random_uuid()}",
                        type="function",
                        function=FunctionCall(
                            name=call["name"],
                            arguments=arguments,
                        ),
                    )
                )

            if not content or len(content.strip()) == 0:
                # clear the whitespace content.
                content = None

            return ExtractedToolCallInformation(
                tools_called=len(tool_calls) > 0,
                tool_calls=tool_calls,
                content=content,
            )

        except Exception:
            # Best-effort boundary: never fail the request because of a
            # malformed tool call, but leave a trace for debugging.
            logger.exception("Error in extracting tool call from response.")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Extract tool calls for streaming mode.

        Returns a content delta while the accumulated text does not look like
        a tool-call array, a tool-call delta (name first, then argument
        fragments) when there is something new to emit, and ``None`` when
        there is nothing to send for this step.
        """
        # Skip leading whitespace and an optional <tool_calls> tag.
        start_idx = consume_space(0, current_text)
        if current_text[start_idx:].startswith(self.bot_string):
            start_idx = consume_space(start_idx + len(self.bot_string), current_text)

        # Anything that is not a JSON array at this point is plain content.
        if (
            not current_text
            or start_idx >= len(current_text)
            or current_text[start_idx] != "["
        ):
            return DeltaMessage(content=delta_text)

        self._try_parse_json_tools(current_text[start_idx:])

        test_delta = self._handle_test_compatibility(current_text)
        if test_delta:
            return test_delta

        # NOTE(review): every `"name": "..."` occurrence counts as a tool,
        # including one nested inside an arguments object - confirm whether
        # that can occur for the supported models.
        name_matches = list(self.tool_name_reg.finditer(current_text))
        tool_count = len(name_matches)
        if tool_count == 0:
            return None

        self._ensure_state_arrays(tool_count)
        current_idx = self.streaming_state["current_tool_index"]

        name_delta = self._handle_tool_name_streaming(
            current_idx, tool_count, name_matches
        )
        if name_delta:
            return name_delta

        args_delta = self._handle_tool_args_streaming(
            current_text, current_idx, tool_count
        )
        if args_delta:
            return args_delta

        return None

    def _try_parse_json_tools(self, current_text: str):
        """Record the parsed tool array in ``prev_tool_call_arr`` if complete.

        ``current_text`` must start at the opening ``[``. Only the leading
        JSON value is decoded, so a complete array is still recognized when
        it is already followed by more text such as ``</tool_calls>``.
        Incomplete JSON is ignored.
        """
        try:
            parsed_tools, _ = json.JSONDecoder().raw_decode(current_text)
        except json.JSONDecodeError:
            return
        if isinstance(parsed_tools, list):
            self.prev_tool_call_arr = parsed_tools

    def _handle_test_compatibility(self, current_text: str):
        """Emit the first tool name when ``current_tools_sent == [False]``.

        This path is only taken when ``current_tools_sent`` was preset to
        exactly ``[False]``. It sends the first tool name found in
        ``current_text`` with index 0 and marks it as sent in all state
        holders. Returns the delta, or ``None`` if the path does not apply
        or no name is present yet.
        """
        if len(self.current_tools_sent) != 1 or self.current_tools_sent[0] is not False:
            return None

        name_match = self.tool_name_reg.search(current_text)
        if not name_match:
            return None

        function_name = name_match.group(1)
        tool_id = f"chatcmpl-tool-{random_uuid()}"
        delta = DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=0,
                    type="function",
                    id=tool_id,
                    function=DeltaFunctionCall(name=function_name).model_dump(
                        exclude_none=True
                    ),
                )
            ]
        )

        self.current_tools_sent = [True]
        self.current_tool_id = 0
        self.streaming_state["current_tool_index"] = 0

        sent_tools = self.streaming_state["sent_tools"]
        if len(sent_tools) == 0:
            sent_tools.append(
                {
                    "sent_name": True,
                    "sent_arguments_prefix": False,
                    "sent_arguments": "",
                }
            )
        else:
            sent_tools[0]["sent_name"] = True

        self.current_tool_name_sent = True
        return delta

    def _ensure_state_arrays(self, tool_count: int):
        """Grow the per-tool streaming state lists to ``tool_count`` entries."""
        sent_tools = self.streaming_state["sent_tools"]
        while len(sent_tools) < tool_count:
            sent_tools.append(
                {
                    "sent_name": False,
                    "sent_arguments_prefix": False,
                    "sent_arguments": "",
                }
            )

        tool_ids = self.streaming_state["tool_ids"]
        while len(tool_ids) < tool_count:
            tool_ids.append(None)

    def _pad_streamed_args(self, tool_idx: int):
        """Make sure ``self.streamed_args`` has a slot for ``tool_idx``."""
        while len(self.streamed_args) <= tool_idx:
            self.streamed_args.append("")

    def _handle_tool_name_streaming(
        self, current_idx: int, tool_count: int, name_matches
    ):
        """Advance to the next tool and emit its name if it was not sent yet.

        Args:
            current_idx: Index of the tool currently being streamed (-1 when
                none has started).
            tool_count: Number of tool names found in the accumulated text.
            name_matches: Matches of ``tool_name_reg``, one per tool.

        Returns:
            A delta carrying the new tool's id and name, or ``None`` when no
            new tool has to be announced.
        """
        if current_idx != -1 and current_idx >= tool_count - 1:
            return None

        next_idx = current_idx + 1
        if (
            next_idx >= tool_count
            or self.streaming_state["sent_tools"][next_idx]["sent_name"]
        ):
            return None

        self.streaming_state["current_tool_index"] = next_idx
        self.current_tool_id = next_idx

        tool_name = name_matches[next_idx].group(1)
        tool_id = f"call_{next_idx}_{random_uuid()}"
        self.streaming_state["tool_ids"][next_idx] = tool_id

        delta = DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=next_idx,
                    type="function",
                    id=tool_id,
                    function=DeltaFunctionCall(name=tool_name).model_dump(
                        exclude_none=True
                    ),
                )
            ]
        )
        self.streaming_state["sent_tools"][next_idx]["sent_name"] = True
        self.current_tool_name_sent = True
        self._pad_streamed_args(next_idx)
        return delta

    def _handle_tool_args_streaming(
        self, current_text: str, current_idx: int, tool_count: int
    ):
        """Emit the not-yet-sent part of the current tool's arguments.

        Args:
            current_text: The full text accumulated so far.
            current_idx: Index of the tool currently being streamed.
            tool_count: Number of tool names found in ``current_text``.

        Returns:
            A delta with an arguments fragment (``"{}"`` for empty arguments,
            otherwise ``"{"`` first and then the remainder), or ``None`` when
            nothing new can be sent.
        """
        if current_idx < 0 or current_idx >= tool_count:
            return None

        sent_tools = self.streaming_state["sent_tools"]

        # Check for empty arguments on the *current* tool only. Searching the
        # whole text would let an earlier tool's `{}` be attributed to the
        # current tool, which would then never stream its real arguments.
        name_matches = list(self.tool_name_reg.finditer(current_text))
        empty_args_match = None
        if current_idx < len(name_matches):
            empty_args_match = self.tool_empty_arg_reg.match(
                current_text, name_matches[current_idx].start()
            )

        if empty_args_match and empty_args_match.start() > 0:
            tool_state = sent_tools[current_idx]
            if not tool_state["sent_arguments_prefix"]:
                tool_state["sent_arguments_prefix"] = True
                tool_state["sent_arguments"] = "{}"
                self._pad_streamed_args(current_idx)
                self.streamed_args[current_idx] += "{}"
                delta = DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=current_idx,
                            function=DeltaFunctionCall(arguments="{}").model_dump(
                                exclude_none=True
                            ),
                        )
                    ]
                )
                if current_idx < tool_count - 1:
                    self.streaming_state["current_tool_index"] += 1
                    self.current_tool_id = self.streaming_state[
                        "current_tool_index"
                    ]
                return delta

        args_matches = list(self.tool_non_empty_arg_reg.finditer(current_text))
        if current_idx >= len(args_matches):
            return None

        # Group 1 is exactly the arguments object, so it is used as-is.
        # Re-slicing the text up to a following "},{" would also capture the
        # closing brace of the enclosing tool object and yield invalid JSON.
        args_text = args_matches[current_idx].group(1)

        tool_state = sent_tools[current_idx]
        sent_args = tool_state["sent_arguments"]

        # First fragment: just the opening brace.
        if not tool_state["sent_arguments_prefix"] and args_text.startswith("{"):
            tool_state["sent_arguments_prefix"] = True
            tool_state["sent_arguments"] = "{"
            self._pad_streamed_args(current_idx)
            self.streamed_args[current_idx] += "{"
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=current_idx,
                        function=DeltaFunctionCall(arguments="{").model_dump(
                            exclude_none=True
                        ),
                    )
                ]
            )

        # Following fragments: whatever extends what was already sent.
        if args_text.startswith(sent_args):
            args_diff = args_text[len(sent_args) :]
            if args_diff:
                tool_state["sent_arguments"] = args_text
                self._pad_streamed_args(current_idx)
                self.streamed_args[current_idx] += args_diff
                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=current_idx,
                            function=DeltaFunctionCall(
                                arguments=args_diff
                            ).model_dump(exclude_none=True),
                        )
                    ]
                )

        # Arguments fully sent: move on if another tool is already known.
        if args_text.endswith("}") and args_text == sent_args:
            if current_idx < tool_count - 1:
                self.streaming_state["current_tool_index"] += 1
                self.current_tool_id = self.streaming_state["current_tool_index"]

        return None
|
||||
Reference in New Issue
Block a user