# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
import json
from collections.abc import Sequence
from typing import Any

import regex as re

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
    ToolParser,
)
from vllm.utils import random_uuid

logger = init_logger(__name__)

class Step3ToolParser(ToolParser):
    """
    Tool parser for a model that uses a specific XML-like format for tool calls.
    This version uses a robust, stateful, cursor-based streaming parser and
    consolidates tool arguments into a single message.

    Expected wire format (steptml):
        <|tool_calls_begin|>
          <|tool_call_begin|>function<|tool_sep|>
            <steptml:invoke name="NAME">
              <steptml:parameter name="KEY">VALUE</steptml:parameter>...
            </steptml:invoke>
          <|tool_call_end|>
          ...more tool calls...
        <|tool_calls_end|>
    """

    # Special marker tokens emitted by the model around tool calls.
    TOOL_CALLS_BEGIN = "<|tool_calls_begin|>"
    TOOL_CALLS_END = "<|tool_calls_end|>"
    TOOL_CALL_BEGIN = "<|tool_call_begin|>"
    TOOL_CALL_END = "<|tool_call_end|>"
    TOOL_SEP = "<|tool_sep|>"
    SPECIAL_TOKENS = [TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END]

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)
        # Cursor into `current_text`: everything before it has been consumed.
        self.position = 0
        # Explicit state flags for robust streaming
        self.tool_block_started = False
        self.tool_block_finished = False

    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Keep special tokens in the detokenized output when tools are in
        play, so the `<|tool_...|>` markers remain visible to this parser."""
        request = super().adjust_request(request)
        if request.tools and request.tool_choice != "none":
            request.skip_special_tokens = False
        return request

    @staticmethod
    def _parse_steptml_invoke(
        action_text: str,
    ) -> tuple[str | None, dict[str, str] | None]:
        """Extract (function_name, params) from a steptml invoke snippet.

        Returns (None, None) when no `<steptml:invoke name="...">` tag is
        present (e.g. the snippet is still streaming in). Parameter values
        are matched with `[^<]*`, so values containing `<` are not supported
        by this grammar; each value is whitespace-stripped.
        """
        func_name_match = re.search(r'<steptml:invoke name="([^"]+)">', action_text)
        if not func_name_match:
            return None, None
        func_name = func_name_match.group(1)

        params: dict[str, str] = {}
        param_matches = re.findall(
            r'<steptml:parameter name="([^"]+)">([^<]*)</steptml:parameter>',
            action_text,
        )
        for name, value in param_matches:
            params[name] = value.strip()
        return func_name, params

    def _cast_arguments(
        self,
        func_name: str,
        params: dict[str, Any],
        request: ChatCompletionRequest,
    ) -> dict[str, Any]:
        """Best-effort cast of string argument values to the JSON-schema
        types declared for `func_name` in the request's tool definitions.

        Mutates and returns `params` in place. Values that fail to parse
        (e.g. int("abc")) are left as their original strings; keys without a
        schema entry are untouched. Only the first tool whose name matches
        is consulted.
        """
        for tool in request.tools or []:
            if tool.function.name == func_name:
                schema = tool.function.parameters or {}
                properties = schema.get("properties", {})
                for key, value in params.items():
                    # Skip values that were already cast on a previous pass.
                    if not isinstance(value, str):
                        continue
                    prop = properties.get(key, {})
                    typ = prop.get("type")
                    if typ == "string":
                        params[key] = value.strip()
                    elif typ == "integer":
                        with contextlib.suppress(ValueError):
                            params[key] = int(value)
                    elif typ == "number":
                        with contextlib.suppress(ValueError):
                            params[key] = float(value)
                    elif typ == "boolean":
                        # Only exact "true"/"false" (case-insensitive) are
                        # converted; anything else stays a string.
                        lower_val = value.lower()
                        params[key] = (
                            lower_val == "true"
                            if lower_val in ("true", "false")
                            else value
                        )
                    elif typ == "null":
                        params[key] = None if value.lower() == "null" else value
                break
        return params

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Incrementally parse `current_text` from `self.position` onward.

        State machine: plain content -> tool block -> per-tool-call parsing
        -> trailing content. Returns a DeltaMessage for each emittable piece
        (content, a tool-call header with its name, or the consolidated
        arguments once a call completes), or None to wait for more tokens.
        """
        # The main loop processes the stream from the last known position.
        while True:
            if self.position >= len(current_text):
                return None  # We've processed the entire stream.

            unprocessed_text = current_text[self.position :]

            # STATE: After all tools are done, all subsequent text is content.
            if self.tool_block_finished:
                self.position = len(current_text)
                return DeltaMessage(content=unprocessed_text)

            # STATE: Before the tool block has started.
            if not self.tool_block_started:
                if unprocessed_text.startswith(self.TOOL_CALLS_BEGIN):
                    self.position += len(self.TOOL_CALLS_BEGIN)
                    self.tool_block_started = True
                    continue  # Token consumed, re-loop.

                start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN)
                if start_pos == -1:
                    # The pending text might be a partial begin-marker; hold
                    # it back rather than emitting it as content.
                    if (
                        self.TOOL_CALLS_BEGIN.startswith(unprocessed_text.strip())
                        and unprocessed_text
                    ):
                        return None  # It's a prefix, wait.
                    self.position = len(current_text)
                    return DeltaMessage(content=unprocessed_text)
                else:
                    # Emit the content that precedes the tool block; the
                    # marker itself is consumed on the next loop iteration.
                    content = unprocessed_text[:start_pos]
                    self.position += len(content)
                    return DeltaMessage(content=content)

            # STATE: Inside the main tool block.
            # Skip inter-token whitespace, keeping the cursor in sync.
            offset = len(unprocessed_text) - len(unprocessed_text.lstrip())
            unprocessed_text = unprocessed_text.lstrip()
            self.position += offset

            if unprocessed_text.startswith(self.TOOL_CALLS_END):
                self.position += len(self.TOOL_CALLS_END)
                self.tool_block_finished = True
                self.current_tool_id = -1
                continue

            # Check if we are between tool calls.
            tool_finished = self.current_tool_id != -1 and self.prev_tool_call_arr[
                self.current_tool_id
            ].get("finished")
            if self.current_tool_id == -1 or tool_finished:
                if unprocessed_text.startswith(self.TOOL_CALL_BEGIN):
                    self.position += len(self.TOOL_CALL_BEGIN)
                    if self.current_tool_id == -1:
                        self.current_tool_id = 0
                    else:
                        self.current_tool_id += 1
                    self.current_tool_name_sent = False
                    # Grow the per-tool state array up to the current index.
                    while len(self.prev_tool_call_arr) <= self.current_tool_id:
                        self.prev_tool_call_arr.append({})
                    self.prev_tool_call_arr[self.current_tool_id]["finished"] = False
                    continue

                # Partial begin-marker: wait for the rest of it.
                if self.TOOL_CALL_BEGIN.startswith(unprocessed_text):
                    return None

            # STATE: Parsing an active tool call.
            if self.current_tool_id != -1 and not self.prev_tool_call_arr[
                self.current_tool_id
            ].get("finished", False):
                end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END)
                if end_tool_pos == -1:
                    tool_body = unprocessed_text
                else:
                    tool_body = unprocessed_text[:end_tool_pos]

                # The body so far could be just a partial end-marker.
                if end_tool_pos == -1 and self.TOOL_CALL_END.startswith(tool_body):
                    return None

                function_name, arguments = self._parse_steptml_invoke(tool_body)
                if not function_name:
                    return None

                tool_call_arr = {"name": function_name, "parameters": arguments or {}}

                # Send the function name as soon as it's parsed.
                if not self.current_tool_name_sent:
                    self.current_tool_name_sent = True
                    self.prev_tool_call_arr[self.current_tool_id].update(tool_call_arr)
                    return DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                type="function",
                                id=f"chatcmpl-tool-{random_uuid()}",
                                function=DeltaFunctionCall(name=function_name),
                            )
                        ]
                    )

                # Update our internal state with the latest parsed arguments.
                self.prev_tool_call_arr[self.current_tool_id].update(  # noqa: E501
                    tool_call_arr
                )

                # Only send arguments when the tool call is complete.
                if end_tool_pos != -1:
                    self.position += end_tool_pos + len(self.TOOL_CALL_END)
                    self.prev_tool_call_arr[self.current_tool_id]["finished"] = True

                    final_args = self._cast_arguments(
                        function_name,
                        tool_call_arr.get("parameters", {}),  # type: ignore
                        request,
                    )
                    if final_args:
                        final_args_json = json.dumps(final_args, ensure_ascii=False)
                        return DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_id,
                                    function=DeltaFunctionCall(
                                        arguments=final_args_json
                                    ),
                                )
                            ]
                        )

                # If tool is not finished, return None to wait for more tokens.
                return None

            return None

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Non-streaming extraction over the complete model output.

        Requires both the begin and end block markers; otherwise the whole
        output is returned as plain content. Text before and after the tool
        block is concatenated into `content`. Malformed call segments
        (missing end marker, missing separator, non-"function" type, or an
        unparseable invoke) are skipped silently.
        """
        if self.TOOL_CALLS_BEGIN not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        pre_text, rest = model_output.split(self.TOOL_CALLS_BEGIN, 1)
        if self.TOOL_CALLS_END not in rest:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        tool_block, post_text = rest.split(self.TOOL_CALLS_END, 1)
        content = (pre_text + post_text).strip()

        tool_calls: list[ToolCall] = []
        call_parts = tool_block.split(self.TOOL_CALL_BEGIN)

        for part in call_parts:
            if not part or self.TOOL_CALL_END not in part:
                continue

            call_content = part.split(self.TOOL_CALL_END, 1)[0]
            if self.TOOL_SEP not in call_content:
                continue

            # Segment layout: "<type><|tool_sep|><steptml:invoke ...>".
            type_part, invoke_part = call_content.split(self.TOOL_SEP, 1)
            if type_part.strip() != "function":
                continue

            function_name, params_dict = self._parse_steptml_invoke(invoke_part)

            if function_name and params_dict is not None:
                params_dict = self._cast_arguments(function_name, params_dict, request)
                params_str = json.dumps(params_dict, ensure_ascii=False)
                tool_calls.append(
                    ToolCall(
                        function=FunctionCall(name=function_name, arguments=params_str)
                    )
                )
        if tool_calls:
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                # Surrounding non-tool text, or None when empty.
                content=content if content else None,
            )
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )