Sync from v0.13

2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

vllm/tool_parsers/__init__.py

@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
ToolParserManager,
)
__all__ = ["ToolParser", "ToolParserManager"]
"""
Register a lazy module mapping.
Example:
ToolParserManager.register_lazy_module(
name="kimi_k2",
module_path="vllm.tool_parsers.kimi_k2_parser",
class_name="KimiK2ToolParser",
)
"""
_TOOL_PARSERS_TO_REGISTER = {
"deepseek_v3": ( # name
"deepseekv3_tool_parser", # filename
"DeepSeekV3ToolParser", # class_name
),
"deepseek_v31": (
"deepseekv31_tool_parser",
"DeepSeekV31ToolParser",
),
"deepseek_v32": (
"deepseekv32_tool_parser",
"DeepSeekV32ToolParser",
),
"ernie45": (
"ernie45_tool_parser",
"Ernie45ToolParser",
),
"glm45": (
"glm4_moe_tool_parser",
"Glm4MoeModelToolParser",
),
"granite-20b-fc": (
"granite_20b_fc_tool_parser",
"Granite20bFCToolParser",
),
"granite": (
"granite_tool_parser",
"GraniteToolParser",
),
"hermes": (
"hermes_tool_parser",
"Hermes2ProToolParser",
),
"hunyuan_a13b": (
"hunyuan_a13b_tool_parser",
"HunyuanA13BToolParser",
),
"internlm": (
"internlm2_tool_parser",
"Internlm2ToolParser",
),
"jamba": (
"jamba_tool_parser",
"JambaToolParser",
),
"kimi_k2": (
"kimi_k2_tool_parser",
"KimiK2ToolParser",
),
"llama3_json": (
"llama_tool_parser",
"Llama3JsonToolParser",
),
"llama4_json": (
"llama_tool_parser",
"Llama3JsonToolParser",
),
"llama4_pythonic": (
"llama4_pythonic_tool_parser",
"Llama4PythonicToolParser",
),
"longcat": (
"longcat_tool_parser",
"LongcatFlashToolParser",
),
"minimax_m2": (
"minimax_m2_tool_parser",
"MinimaxM2ToolParser",
),
"minimax": (
"minimax_tool_parser",
"MinimaxToolParser",
),
"mistral": (
"mistral_tool_parser",
"MistralToolParser",
),
"olmo3": (
"olmo3_tool_parser",
"Olmo3PythonicToolParser",
),
"openai": (
"openai_tool_parser",
"OpenAIToolParser",
),
"phi4_mini_json": (
"phi4mini_tool_parser",
"Phi4MiniJsonToolParser",
),
"pythonic": (
"pythonic_tool_parser",
"PythonicToolParser",
),
"qwen3_coder": (
"qwen3coder_tool_parser",
"Qwen3CoderToolParser",
),
"qwen3_xml": (
"qwen3xml_tool_parser",
"Qwen3XMLToolParser",
),
"seed_oss": (
"seed_oss_tool_parser",
"SeedOssToolParser",
),
"step3": (
"step3_tool_parser",
"Step3ToolParser",
),
"xlam": (
"xlam_tool_parser",
"xLAMToolParser",
),
"gigachat3": (
"gigachat3_tool_parser",
"GigaChat3ToolParser",
),
}
def register_lazy_tool_parsers():
for name, (file_name, class_name) in _TOOL_PARSERS_TO_REGISTER.items():
module_path = f"vllm.tool_parsers.{file_name}"
ToolParserManager.register_lazy_module(name, module_path, class_name)
register_lazy_tool_parsers()
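
For reference, a minimal sketch (illustrative, not part of this commit) of how the lazy registry above is consumed: get_tool_parser resolves a name, importing and caching the class on first access. The tokenizer below is assumed to be any TokenizerLike instance already in scope.

from vllm.tool_parsers import ToolParserManager

# First access triggers the lazy import; the class is then cached.
parser_cls = ToolParserManager.get_tool_parser("hermes")
parser = parser_cls(tokenizer)  # `tokenizer`: assumed TokenizerLike instance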

vllm/tool_parsers/abstract_tool_parser.py

@@ -0,0 +1,273 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import os
from collections.abc import Callable, Sequence
from functools import cached_property
from openai.types.responses.response_format_text_json_schema_config import (
ResponseFormatTextJSONSchemaConfig,
)
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation,
ResponsesRequest,
ResponseTextConfig,
)
from vllm.logger import init_logger
from vllm.sampling_params import (
StructuredOutputsParams,
)
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.utils import get_json_schema_from_tools
from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import import_from_path
logger = init_logger(__name__)
class ToolParser:
"""
Abstract ToolParser class that should not be used directly. Derived
classes should implement the extraction methods and can rely on the
provided properties and helpers.
"""
def __init__(self, tokenizer: TokenizerLike):
self.prev_tool_call_arr: list[dict] = []
# the index of the tool call that is currently being parsed
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: list[str] = []
self.model_tokenizer = tokenizer
@cached_property
def vocab(self) -> dict[str, int]:
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
# whereas all tokenizers have .get_vocab()
return self.model_tokenizer.get_vocab()
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
"""
Instance method used to adjust the request parameters before generation
(e.g. enabling structured outputs for tool calling).
"""
if not request.tools:
return request
json_schema_from_tool = get_json_schema_from_tools(
tool_choice=request.tool_choice, tools=request.tools
)
# Set structured output params for tool calling
if json_schema_from_tool is not None:
if isinstance(request, ChatCompletionRequest):
request.structured_outputs = StructuredOutputsParams()
# tool_choice: "Forced Function" or "required" will override
# structured output json settings to make tool calling work correctly
request.structured_outputs.json = json_schema_from_tool
if isinstance(request, ResponsesRequest):
request.text = ResponseTextConfig()
request.text.format = ResponseFormatTextJSONSchemaConfig(
name="tool_calling_response",
schema=json_schema_from_tool,
type="json_schema",
description="Response format for tool calling",
strict=True,
)
return request
def extract_tool_calls(
self, model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation:
"""
Method that should be implemented for extracting tool calls from
a complete model-generated string.
Used for non-streaming responses where the entire model response is
available before sending to the client.
"""
raise NotImplementedError(
"AbstractToolParser.extract_tool_calls has not been implemented!"
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
"""
Instance method that should be implemented for extracting tool calls
from an incomplete response; for use when handling tool calls and
streaming. Has to be an instance method because it requires state -
the current tokens/diffs, but also the information about what has
previously been parsed and extracted (see constructor)
"""
raise NotImplementedError(
"AbstractToolParser.extract_tool_calls_streaming has not been implemented!"
)
class ToolParserManager:
"""
Central registry for ToolParser implementations.
Supports two modes:
- Eager (immediate) registration via `register_module`
- Lazy registration via `register_lazy_module`
"""
tool_parsers: dict[str, type[ToolParser]] = {}
lazy_parsers: dict[str, tuple[str, str]] = {} # name -> (module_path, class_name)
@classmethod
def get_tool_parser(cls, name: str) -> type[ToolParser]:
"""
Retrieve a registered or lazily registered ToolParser class.
If the parser is lazily registered,
it will be imported and cached on first access.
Raises KeyError if not found.
"""
if name in cls.tool_parsers:
return cls.tool_parsers[name]
if name in cls.lazy_parsers:
return cls._load_lazy_parser(name)
raise KeyError(f"Tool parser '{name}' not found.")
@classmethod
def _load_lazy_parser(cls, name: str) -> type[ToolParser]:
"""Import and register a lazily loaded parser."""
module_path, class_name = cls.lazy_parsers[name]
try:
mod = importlib.import_module(module_path)
parser_cls = getattr(mod, class_name)
if not issubclass(parser_cls, ToolParser):
raise TypeError(
f"{class_name} in {module_path} is not a ToolParser subclass."
)
cls.tool_parsers[name] = parser_cls # cache
return parser_cls
except Exception as e:
logger.exception(
"Failed to import lazy tool parser '%s' from %s: %s",
name,
module_path,
e,
)
raise
@classmethod
def _register_module(
cls,
module: type[ToolParser],
module_name: str | list[str] | None = None,
force: bool = True,
) -> None:
"""Register a ToolParser class immediately."""
if not issubclass(module, ToolParser):
raise TypeError(
f"module must be subclass of ToolParser, but got {type(module)}"
)
if module_name is None:
module_name = module.__name__
if isinstance(module_name, str):
module_names = [module_name]
elif is_list_of(module_name, str):
module_names = module_name
else:
raise TypeError("module_name must be str, list[str], or None.")
for name in module_names:
if not force and name in cls.tool_parsers:
existed = cls.tool_parsers[name]
raise KeyError(f"{name} is already registered at {existed.__module__}")
cls.tool_parsers[name] = module
@classmethod
def register_lazy_module(cls, name: str, module_path: str, class_name: str) -> None:
"""
Register a lazy module mapping.
Example:
ToolParserManager.register_lazy_module(
name="kimi_k2",
module_path="vllm.tool_parsers.kimi_k2_parser",
class_name="KimiK2ToolParser",
)
"""
cls.lazy_parsers[name] = (module_path, class_name)
@classmethod
def register_module(
cls,
name: str | list[str] | None = None,
force: bool = True,
module: type[ToolParser] | None = None,
) -> type[ToolParser] | Callable[[type[ToolParser]], type[ToolParser]]:
"""
Register module immediately or lazily (as a decorator).
Usage:
@ToolParserManager.register_module("kimi_k2")
class KimiK2ToolParser(ToolParser):
...
Or:
ToolParserManager.register_module(module=SomeToolParser)
"""
if not isinstance(force, bool):
raise TypeError(f"force must be a boolean, but got {type(force)}")
# Immediate registration
if module is not None:
cls._register_module(module=module, module_name=name, force=force)
return module
# Decorator usage
def _decorator(obj: type[ToolParser]) -> type[ToolParser]:
module_path = obj.__module__
class_name = obj.__name__
if isinstance(name, str):
names = [name]
elif name is not None and is_list_of(name, str):
names = name
else:
names = [class_name]
for n in names:
# Lazy mapping only: do not import now
cls.lazy_parsers[n] = (module_path, class_name)
return obj
return _decorator
@classmethod
def list_registered(cls) -> list[str]:
"""Return names of all eagerly and lazily registered tool parsers."""
return sorted(set(cls.tool_parsers.keys()) | set(cls.lazy_parsers.keys()))
@classmethod
def import_tool_parser(cls, plugin_path: str) -> None:
"""Import a user-defined parser file from arbitrary path."""
module_name = os.path.splitext(os.path.basename(plugin_path))[0]
try:
import_from_path(module_name, plugin_path)
except Exception:
logger.exception(
"Failed to load module '%s' from %s.", module_name, plugin_path
)
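
A hedged sketch of the decorator path documented in register_module: applying the decorator records only a lazy (module_path, class_name) mapping derived from the class itself, deferring any re-import until first lookup. The parser name and class below are hypothetical.

from vllm.tool_parsers.abstract_tool_parser import ToolParser, ToolParserManager

@ToolParserManager.register_module("my_model")  # hypothetical parser name
class MyModelToolParser(ToolParser):
    # In practice, override extract_tool_calls / extract_tool_calls_streaming.
    ...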

vllm/tool_parsers/deepseekv31_tool_parser.py

@@ -0,0 +1,388 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser
logger = init_logger(__name__)
class DeepSeekV31ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.streamed_args_for_tool: list[str] = []  # arguments streamed so far, per tool call
self.tool_calls_start_token: str = "<tool▁calls▁begin>"
self.tool_calls_end_token: str = "<tool▁calls▁end>"
self.tool_call_start_token: str = "<tool▁call▁begin>"
self.tool_call_end_token: str = "<tool▁call▁end>"
self.tool_call_regex = re.compile(
r"<tool▁call▁begin>(?P<function_name>.*?)<tool▁sep>(?P<function_arguments>.*?)<tool▁call▁end>"
)
self.stream_tool_call_portion_regex = re.compile(
r"(?P<function_name>.*)<tool▁sep>(?P<function_arguments>.*)"
)
self.stream_tool_call_name_regex = re.compile(
r"(?P<function_name>.*)<tool▁sep>"
)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token)
self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token)
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if (
self.tool_calls_start_token_id is None
or self.tool_calls_end_token_id is None
):
raise RuntimeError(
"DeepSeek-V3.1 Tool parser could not locate tool call "
"start/end tokens in the tokenizer!"
)
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
# sanity check; avoid unnecessary processing
if self.tool_calls_start_token not in model_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
else:
try:
# findall with the named-group regex returns a list of
# (function_name, function_arguments) tuples, one per tool call
function_call_tuples = self.tool_call_regex.findall(model_output)
tool_calls = []
for match in function_call_tuples:
function_name, function_args = match
tool_calls.append(
ToolCall(
type="function",
function=FunctionCall(
name=function_name, arguments=function_args
),
)
)
content = model_output[: model_output.find(self.tool_calls_start_token)]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if content else None,
)
except Exception:
logger.exception("Error in extracting tool call from response.")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
logger.debug("delta_text: %s", delta_text)
logger.debug("delta_token_ids: %s", delta_token_ids)
# check to see if we should be streaming a tool call - is there a
# tool-calls start token in the generated text so far?
if self.tool_calls_start_token_id not in current_token_ids:
logger.debug("No tool call tokens found!")
return DeltaMessage(content=delta_text)
delta_text = delta_text.replace(self.tool_calls_start_token, "").replace(
self.tool_calls_end_token, ""
)
try:
# figure out where we are in the parsing by counting tool call
# start & end tags
prev_tool_start_count = previous_token_ids.count(
self.tool_call_start_token_id
)
prev_tool_end_count = previous_token_ids.count(self.tool_call_end_token_id)
cur_tool_start_count = current_token_ids.count(
self.tool_call_start_token_id
)
cur_tool_end_count = current_token_ids.count(self.tool_call_end_token_id)
tool_call_portion = None
text_portion = None
# case: if we're generating text, OR rounding out a tool call
if (
cur_tool_start_count == cur_tool_end_count
and prev_tool_end_count == cur_tool_end_count
and self.tool_call_end_token not in delta_text
):
logger.debug("Generating text content! skipping tool parsing.")
return DeltaMessage(content=delta_text)
if self.tool_call_end_token in delta_text:
logger.debug("tool_call_end_token in delta_text")
full_text = current_text + delta_text
tool_call_portion = (
full_text.split(self.tool_call_start_token)[-1]
.split(self.tool_call_end_token)[0]
.rstrip()
)
delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip()
text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip()
# case -- we're starting a new tool call
if (
cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count > prev_tool_start_count
):
if len(delta_token_ids) > 1:
tool_call_portion = current_text.split(self.tool_call_start_token)[
-1
]
else:
tool_call_portion = None
delta = None
text_portion = None
# set cursors and state appropriately
self.current_tool_id += 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("Starting on a new tool %s", self.current_tool_id)
# case -- we're updating an existing tool call
elif (
cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count == prev_tool_start_count
):
# get the portion of the text that's the tool call
tool_call_portion = current_text.split(self.tool_call_start_token)[-1]
text_portion = None
# case -- the current tool call is being closed.
elif (
cur_tool_start_count == cur_tool_end_count
and cur_tool_end_count >= prev_tool_end_count
):
if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0:
logger.debug("attempting to close tool call, but no tool call")
return None
diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
if diff:
diff = (
diff.encode("utf-8").decode("unicode_escape")
if isinstance(diff, str)
else diff
)
if '"}' not in delta_text:
return None
end_loc = delta_text.rindex('"}')
diff = delta_text[:end_loc] + '"}'
logger.debug(
"Finishing tool and found diff that had not "
"been streamed yet: %s",
diff,
)
self.streamed_args_for_tool[self.current_tool_id] += diff
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(arguments=diff).model_dump(
exclude_none=True
),
)
]
)
# case -- otherwise we're just generating text
else:
text = delta_text.replace(self.tool_call_start_token, "")
text = text.replace(self.tool_call_end_token, "")
delta = DeltaMessage(tool_calls=[], content=text)
return delta
current_tool_call = dict()
if tool_call_portion:
current_tool_call_matches = self.stream_tool_call_portion_regex.match(
tool_call_portion
)
if current_tool_call_matches:
tool_name, tool_args = current_tool_call_matches.groups()
current_tool_call["name"] = tool_name
current_tool_call["arguments"] = tool_args
else:
current_tool_call_name_matches = (
self.stream_tool_call_name_regex.match(tool_call_portion)
)
if current_tool_call_name_matches:
tool_name = current_tool_call_name_matches.group("function_name")
current_tool_call["name"] = tool_name
current_tool_call["arguments"] = ""
else:
logger.debug("Not enough token")
return None
# case - we haven't sent the tool name yet. If it's available, send
# it. otherwise, wait until it's available.
if not self.current_tool_name_sent:
if current_tool_call is None:
return None
function_name: str | None = current_tool_call.get("name")
if function_name:
self.current_tool_name_sent = True
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name
).model_dump(exclude_none=True),
)
]
)
else:
return None
# case -- otherwise, send the tool call delta
# if the tool call portion is None, send the delta as text
if tool_call_portion is None:
# if there's text but not tool calls, send that -
# otherwise None to skip chunk
delta = (
DeltaMessage(content=delta_text)
if text_portion is not None
else None
)
return delta
# now, the nitty-gritty of tool calls
# now we have the portion to parse as tool call.
logger.debug(
"Trying to parse current tool call with ID %s", self.current_tool_id
)
# if we're starting a new tool call, push an empty object in as
# a placeholder for the arguments
if len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
# main logic for tool parsing here - compare prev. partially-parsed
# JSON to the current partially-parsed JSON
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments"
)
cur_arguments = current_tool_call.get("arguments")
logger.debug("diffing old arguments: %s", prev_arguments)
logger.debug("against new ones: %s", cur_arguments)
# case -- no arguments have been created yet. skip sending a delta.
if not cur_arguments and not prev_arguments:
logger.debug("Skipping text %s - no arguments", delta_text)
delta = None
# case -- prev arguments are defined, but none are now.
# probably impossible, but not a fatal error - just keep going
elif not cur_arguments and prev_arguments:
logger.error(
"should be impossible to have arguments reset "
"mid-call. skipping streaming anything."
)
delta = None
# case -- we now have the first info about arguments available from
# autocompleting the JSON
elif cur_arguments and not prev_arguments:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=cur_arguments
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] = cur_arguments
# last case -- we have an update to existing arguments.
elif cur_arguments and prev_arguments:
if (
isinstance(delta_text, str)
and cur_arguments != prev_arguments
and len(cur_arguments) > len(prev_arguments)
and cur_arguments.startswith(prev_arguments)
):
delta_arguments = cur_arguments[len(prev_arguments) :]
logger.debug("got diff %s", delta_text)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=delta_arguments
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] = cur_arguments
else:
delta = None
# handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
else:
self.prev_tool_call_arr.append(current_tool_call)
return delta
except Exception:
logger.exception("Error trying to handle streaming tool call.")
return None # do not stream a delta. skip this token ID.
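
For reference, a small self-contained sketch of the V3.1 wire format that extract_tool_calls matches (the stdlib re module behaves the same as the regex package for this pattern):

import re

# Same per-call pattern as DeepSeekV31ToolParser.tool_call_regex.
pattern = re.compile(
    r"<tool▁call▁begin>(?P<function_name>.*?)<tool▁sep>"
    r"(?P<function_arguments>.*?)<tool▁call▁end>"
)
sample = (
    "<tool▁calls▁begin><tool▁call▁begin>get_weather<tool▁sep>"
    '{"location": "Hangzhou"}<tool▁call▁end><tool▁calls▁end>'
)
for name, args in pattern.findall(sample):
    print(name, args)  # -> get_weather {"location": "Hangzhou"}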

vllm/tool_parsers/deepseekv32_tool_parser.py

@@ -0,0 +1,591 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import uuid
from collections.abc import Sequence
from typing import Any
import regex as re
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
logger = init_logger(__name__)
class DeepSeekV32ToolParser(ToolParser):
"""
example tool call content:
<DSMLfunction_calls>
<DSMLinvoke name="get_weather">
<DSMLparameter name="location" string="true">杭州</DSMLparameter>
<DSMLparameter name="date" string="true">2024-01-16</DSMLparameter>
</DSMLinvoke>
<DSMLinvoke name="get_weather">
<DSMLparameter name="location" string="true">北京</DSMLparameter>
<DSMLparameter name="date" string="true">2024-01-16</DSMLparameter>
</DSMLinvoke>
</DSMLfunction_calls>
"""
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.prev_tool_call_arr: list[dict] = []
# Sentinel tokens
self.dsml_token: str = "DSML"
self.dsml_start_check: str = "<" + self.dsml_token
self.tool_call_start_token: str = "<DSMLfunction_calls>"
self.tool_call_end_token: str = "</DSMLfunction_calls>"
self.invoke_start_prefix: str = "<DSMLinvoke name="
self.invoke_end_token: str = "</DSMLinvoke>"
self.parameter_prefix: str = "<DSMLparameter name="
self.parameter_end_token: str = "</DSMLparameter>"
# Streaming state variables
self.current_tool_name_sent: bool = False
# Override base class type - we use string IDs for tool calls
self.current_tool_id: str | None = None # type: ignore
self.streamed_args_for_tool: list[str] = []
self.is_tool_call_started: bool = False
self.failed_count: int = 0
# Initialize streaming state variables
self.current_tool_index: int = 0
self.invoke_index: int = 0
self.header_sent: bool = False
self.current_function_name: str | None = None
self.current_param_name: str | None = None
self.current_param_value: str = ""
self.param_count: int = 0
self.in_param: bool = False
self.in_function: bool = False
self.json_started: bool = False
self.json_closed: bool = False
self.accumulated_params: dict = {}
self.streaming_request: ChatCompletionRequest | None = None
# Enhanced streaming state - reset for each new message
self._reset_streaming_state()
# Regex patterns for complete parsing
self.tool_call_complete_regex = re.compile(
r"<DSMLfunction_calls>(.*?)</DSMLfunction_calls>", re.DOTALL
)
self.invoke_complete_regex = re.compile(
r'<DSMLinvoke\s+name="([^"]+)"\s*>(.*?)</DSMLinvoke>', re.DOTALL
)
self.parameter_complete_regex = re.compile(
r'<DSMLparameter\s+name="([^"]+)"\s+string="(?:true|false)"\s*>(.*?)</DSMLparameter>',
re.DOTALL,
)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
logger.debug(
"vLLM Successfully import tool parser %s !", self.__class__.__name__
)
def _generate_tool_call_id(self) -> str:
"""Generate a unique tool call ID."""
return f"call_{uuid.uuid4().hex[:24]}"
def _reset_streaming_state(self):
"""Reset all streaming state."""
self.current_tool_index = 0
self.invoke_index = 0
self.is_tool_call_started = False
self.header_sent = False
self.current_tool_id = None
self.current_function_name = None
self.current_param_name = None
self.current_param_value = ""
self.param_count = 0
self.in_param = False
self.in_function = False
self.json_started = False
self.json_closed = False
# Store accumulated parameters for type conversion
self.accumulated_params = {}
self.streaming_request = None
# Clear previous tool call history to avoid state pollution
self.prev_tool_call_arr.clear()
def _parse_invoke_params(self, invoke_str: str) -> dict | None:
param_dict = dict()
for param_name, param_val in self.parameter_complete_regex.findall(invoke_str):
param_dict[param_name] = param_val
return param_dict
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
"""Extract tool calls from complete model output (non-streaming)."""
# Quick check
if self.tool_call_start_token not in model_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
try:
tool_calls = []
# Find all complete tool_call blocks
for tool_call_match in self.tool_call_complete_regex.findall(model_output):
# Find all invokes within this tool_call
for invoke_name, invoke_content in self.invoke_complete_regex.findall(
tool_call_match
):
param_dict = self._parse_invoke_params(invoke_content)
tool_calls.append(
ToolCall(
type="function",
function=FunctionCall(
name=invoke_name,
arguments=json.dumps(param_dict, ensure_ascii=False),
),
)
)
if not tool_calls:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
# Extract content before first tool call
first_tool_idx = model_output.find(self.tool_call_start_token)
content = model_output[:first_tool_idx] if first_tool_idx > 0 else None
return ExtractedToolCallInformation(
tools_called=True, tool_calls=tool_calls, content=content
)
except Exception:
logger.exception("Error extracting tool calls")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def _extract_name(self, name_str: str) -> str:
"""Extract name from quoted string."""
name_str = name_str.strip()
if (
name_str.startswith('"')
and name_str.endswith('"')
or name_str.startswith("'")
and name_str.endswith("'")
):
return name_str[1:-1]
return name_str
def _extract_param_name(self, input_str: str) -> str:
"""Extract param name"""
start = input_str.find('"') + 1
end = input_str.find('"', start)
return input_str[start:end] if start > 0 and end > start else input_str
def _convert_param_value(self, value: str, param_type: str) -> Any:
"""Convert parameter value to the correct type."""
if value.lower() == "null":
return None
param_type = param_type.lower()
if param_type in ["string", "str", "text"]:
return value
elif param_type in ["integer", "int"]:
try:
return int(value)
except (ValueError, TypeError):
return value
elif param_type in ["number", "float"]:
try:
val = float(value)
return val if val != int(val) else int(val)
except (ValueError, TypeError):
return value
elif param_type in ["boolean", "bool"]:
return value.lower() in ["true", "1"]
elif param_type in ["object", "array"]:
try:
return json.loads(value)
except json.JSONDecodeError:
return value
else:
# Try JSON parse first, fallback to string
try:
return json.loads(value)
except json.JSONDecodeError:
return value
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int], # pylint: disable=unused-argument
current_token_ids: Sequence[int], # pylint: disable=unused-argument
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
"""Extract tool calls from streaming model output."""
# Store request for type conversion
if not previous_text:
self._reset_streaming_state()
self.streaming_request = request
# If no delta text, return None unless it's an EOS token after tools
if not delta_text:
# Check if this is an EOS token after all tool calls are complete
if delta_token_ids:
# Count complete tool calls
complete_calls = len(
self.tool_call_complete_regex.findall(current_text)
)
# If we have completed tool calls and populated prev_tool_call_arr
if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
# Check if all tool calls are closed
open_calls = current_text.count(
self.tool_call_start_token
) - current_text.count(self.tool_call_end_token)
if open_calls == 0:
# Return empty delta for finish_reason processing
return DeltaMessage(content="")
elif not self.is_tool_call_started and current_text:
# This is a regular content response that's now complete
return DeltaMessage(content="")
return None
# Check if we need to advance to next tool
if self.json_closed and not self.in_function:
# Check if this tool call has ended
invoke_ends = current_text.count(self.invoke_end_token)
if invoke_ends > self.current_tool_index:
# This tool has ended, advance to next
self.current_tool_index += 1
self.header_sent = False
self.param_count = 0
self.json_started = False
self.json_closed = False
self.in_function = False # Now we can safely set this to False
self.accumulated_params = {}
# Continue processing next tool
return None
# Handle normal content before tool calls
if not self.is_tool_call_started:
# Check if tool call is starting
if self.dsml_token in current_text:
self.is_tool_call_started = True
# Return any content before the tool call
if self.dsml_start_check in delta_text:
content_before = delta_text[
: delta_text.index(self.dsml_start_check)
]
if content_before:
return DeltaMessage(content=content_before)
return None
else:
# Check if we're between tool calls - skip whitespace
if (
current_text.rstrip().endswith(self.tool_call_end_token)
and delta_text.strip() == ""
):
# We just ended a tool call, skip whitespace
return None
# Normal content, no tool call
if delta_text.endswith("<"):
return DeltaMessage(content=delta_text[:-1])
if previous_text and previous_text.endswith("<"):
return DeltaMessage(content="<" + delta_text)
return DeltaMessage(content=delta_text)
# Check if we're between tool calls (waiting for next one)
invoke_starts_count = current_text.count(self.invoke_start_prefix)
if self.current_tool_index >= invoke_starts_count:
# We're past all tool calls, shouldn't be here
return None
# Find the current tool call portion
invoke_start_positions: list[int] = []
idx = 0
while True:
idx = current_text.find(self.invoke_start_prefix, idx)
if idx == -1:
break
invoke_start_positions.append(idx)
idx += len(self.invoke_start_prefix)
if self.current_tool_index >= len(invoke_start_positions):
# No more tool calls to process yet
return None
invoke_start_idx = invoke_start_positions[self.current_tool_index]
# Find where this tool call ends (or current position if not ended yet)
invoke_end_idx = current_text.find(self.invoke_end_token, invoke_start_idx)
if invoke_end_idx == -1:
tool_text = current_text[invoke_start_idx:]
else:
tool_text = current_text[
invoke_start_idx : invoke_end_idx + len(self.invoke_end_token)
]
# Looking for function header
if not self.header_sent:
if self.invoke_start_prefix in tool_text:
func_start = tool_text.find(self.invoke_start_prefix) + len(
self.invoke_start_prefix
)
# Find the closing '>' of the invoke tag
func_end = tool_text.find(">", func_start)
if func_end != -1:
# Found complete function name
function_name_raw = tool_text[func_start:func_end]
self.current_function_name = self._extract_name(function_name_raw)
self.current_tool_id = self._generate_tool_call_id()
self.header_sent = True
self.in_function = True
# Add to prev_tool_call_arr immediately when we detect a tool call
# Each tool call should be recorded regardless of function name
# Ensure we don't add the same tool call index multiple times
if len(self.prev_tool_call_arr) <= self.current_tool_index:
self.prev_tool_call_arr.append(
{
"name": self.current_function_name,
"arguments": "{}", # Placeholder, will be updated later
}
)
# Send header with function info
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
id=self.current_tool_id,
function=DeltaFunctionCall(
name=self.current_function_name, arguments=""
),
type="function",
)
]
)
return None
# We've sent header, now handle function body
if self.in_function:
# Send opening brace if not sent yet
if self.in_function and not self.json_started:
self.json_started = True
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="{"),
)
]
)
# Make sure json_started is set if we're processing parameters
if not self.json_started:
self.json_started = True
# Check for function end in accumulated text
if not self.json_closed and self.invoke_end_token in tool_text:
# Count total parameters in the tool text
total_param_count = tool_text.count(self.parameter_prefix)
# Only close JSON if all parameters have been processed
if self.param_count >= total_param_count:
# Close JSON
self.json_closed = True
# Extract complete tool call
# Find the invoke content
invoke_start = tool_text.find(self.invoke_start_prefix) + len(
self.invoke_start_prefix
)
invoke_content_end = tool_text.find(
self.invoke_end_token, invoke_start
)
if invoke_content_end != -1:
invoke_content = tool_text[invoke_start:invoke_content_end]
# Parse to get the complete arguments
try:
invoke_params = self._parse_invoke_params(invoke_content)
if invoke_params and self.current_tool_index < len(
self.prev_tool_call_arr
):
# Update existing entry in prev_tool_call_arr
self.prev_tool_call_arr[self.current_tool_index][
"arguments"
] = json.dumps(invoke_params, ensure_ascii=False)
except Exception:
pass # Ignore parsing errors during streaming
result = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="}"),
)
]
)
# Reset state for next tool
self.json_closed = True
self.in_function = False
self.accumulated_params = {}
logger.debug("[M2_STREAMING] Tool call completed")
return result
else:
# Don't close JSON yet, continue processing parameters
return None
# Look for parameters
# Find all parameter starts
param_starts = []
idx = 0
while True:
idx = tool_text.find(self.parameter_prefix, idx)
if idx == -1:
break
param_starts.append(idx)
idx += len(self.parameter_prefix)
# Check if we should start a new parameter
if (
not self.in_param
and self.param_count < len(param_starts)
and len(param_starts) > self.param_count
):
# Process the next parameter
param_idx = param_starts[self.param_count]
param_start = param_idx + len(self.parameter_prefix)
remaining = tool_text[param_start:]
if ">" in remaining:
# We have the complete parameter name
name_end = remaining.find(">")
param_name_raw = remaining[:name_end]
self.current_param_name = self._extract_param_name(param_name_raw)
# Find the parameter value
value_start = param_start + name_end + 1
value_text = tool_text[value_start:]
if value_text.startswith("\n"):
value_text = value_text[1:]
# Find where this parameter ends
param_end_idx = value_text.find(self.parameter_end_token)
if param_end_idx == -1:
# No closing tag, look for next parameter or function end
next_param_idx = value_text.find(self.parameter_prefix)
func_end_idx = value_text.find(self.invoke_end_token)
if next_param_idx != -1 and (
func_end_idx == -1 or next_param_idx < func_end_idx
):
param_end_idx = next_param_idx
elif func_end_idx != -1:
param_end_idx = func_end_idx
else:
# Neither found, check if tool call is complete
if self.invoke_end_token in tool_text:
# Tool call and parameter is complete
param_end_idx = len(value_text)
else:
# Still streaming, wait for more content
return None
if param_end_idx != -1:
# Complete parameter found
param_value = value_text[:param_end_idx]
if param_value.endswith("\n"):
param_value = param_value[:-1]
# Store raw value for later processing
self.accumulated_params[self.current_param_name] = param_value
# Get parameter configuration for type conversion
param_config = {}
if self.streaming_request and self.streaming_request.tools:
for tool in self.streaming_request.tools:
if (
hasattr(tool, "function")
and tool.function.name == self.current_function_name
and hasattr(tool.function, "parameters")
):
params = tool.function.parameters
if (
isinstance(params, dict)
and "properties" in params
):
param_config = params["properties"]
break
# Get parameter type
param_type = "string"
if (
self.current_param_name in param_config
and isinstance(param_config[self.current_param_name], dict)
and "type" in param_config[self.current_param_name]
):
param_type = param_config[self.current_param_name]["type"]
# Convert param value to appropriate type
converted_value = self._convert_param_value(
param_value, param_type
)
# Build JSON fragment based on the converted type
# Use json.dumps to properly serialize the value
serialized_value = json.dumps(
converted_value, ensure_ascii=False
)
if self.param_count == 0:
json_fragment = (
f'"{self.current_param_name}": {serialized_value}'
)
else:
json_fragment = (
f', "{self.current_param_name}": {serialized_value}'
)
self.param_count += 1
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments=json_fragment),
)
]
)
return None
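
A hedged sketch of the DSML parameter extraction performed by _parse_invoke_params, using input adapted from the class docstring:

import json
import re

# Same parameter pattern as parameter_complete_regex above.
param_re = re.compile(
    r'<DSMLparameter\s+name="([^"]+)"\s+string="(?:true|false)"\s*>'
    r"(.*?)</DSMLparameter>",
    re.DOTALL,
)
invoke_body = (
    '<DSMLparameter name="location" string="true">杭州</DSMLparameter>\n'
    '<DSMLparameter name="date" string="true">2024-01-16</DSMLparameter>'
)
args = dict(param_re.findall(invoke_body))
print(json.dumps(args, ensure_ascii=False))
# -> {"location": "杭州", "date": "2024-01-16"}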

vllm/tool_parsers/deepseekv3_tool_parser.py

@@ -0,0 +1,390 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
logger = init_logger(__name__)
class DeepSeekV3ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.streamed_args_for_tool: list[str] = []  # arguments streamed so far, per tool call
self.tool_calls_start_token: str = "<tool▁calls▁begin>"
self.tool_calls_end_token: str = "<tool▁calls▁end>"
self.tool_call_start_token: str = "<tool▁call▁begin>"
self.tool_call_end_token: str = "<tool▁call▁end>"
self.tool_call_regex = re.compile(
r"<tool▁call▁begin>(?P<type>.*)<tool▁sep>(?P<function_name>.*)\n```json\n(?P<function_arguments>.*)\n```<tool▁call▁end>"
)
self.stream_tool_call_portion_regex = re.compile(
r"(?P<type>.*)<tool▁sep>(?P<function_name>.*)\n```json\n(?P<function_arguments>.*[^\n`])"
)
self.stream_tool_call_name_regex = re.compile(
r"(?P<type>.*)<tool▁sep>(?P<function_name>.*)\n"
)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token)
self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token)
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if (
self.tool_calls_start_token_id is None
or self.tool_calls_end_token_id is None
):
raise RuntimeError(
"DeepSeek-V3 Tool parser could not locate tool call start/end "
"tokens in the tokenizer!"
)
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
# sanity check; avoid unnecessary processing
if self.tool_calls_start_token not in model_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
else:
try:
# findall with the named-group regex returns a list of
# (type, function_name, function_arguments) tuples, one per tool call
function_call_tuples = self.tool_call_regex.findall(model_output)
tool_calls = []
for match in function_call_tuples:
tool_type, function_name, function_args = match
tool_calls.append(
ToolCall(
type=tool_type,
function=FunctionCall(
name=function_name, arguments=function_args
),
)
)
content = model_output[: model_output.find(self.tool_calls_start_token)]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if content else None,
)
except Exception:
logger.exception("Error in extracting tool call from response.")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
logger.debug("delta_text: %s", delta_text)
logger.debug("delta_token_ids: %s", delta_token_ids)
# check to see if we should be streaming a tool call - is there a
# tool-calls start token in the generated text so far?
if self.tool_calls_start_token_id not in current_token_ids:
logger.debug("No tool call tokens found!")
return DeltaMessage(content=delta_text)
delta_text = delta_text.replace(self.tool_calls_start_token, "").replace(
self.tool_calls_end_token, ""
)
try:
# figure out where we are in the parsing by counting tool call
# start & end tags
prev_tool_start_count = previous_token_ids.count(
self.tool_call_start_token_id
)
prev_tool_end_count = previous_token_ids.count(self.tool_call_end_token_id)
cur_tool_start_count = current_token_ids.count(
self.tool_call_start_token_id
)
cur_tool_end_count = current_token_ids.count(self.tool_call_end_token_id)
tool_call_portion = None
text_portion = None
# case: if we're generating text, OR rounding out a tool call
if (
cur_tool_start_count == cur_tool_end_count
and prev_tool_end_count == cur_tool_end_count
and self.tool_call_end_token not in delta_text
):
logger.debug("Generating text content! skipping tool parsing.")
return DeltaMessage(content=delta_text)
if self.tool_call_end_token in delta_text:
logger.debug("tool_call_end_token in delta_text")
full_text = current_text + delta_text
tool_call_portion = (
full_text.split(self.tool_call_start_token)[-1]
.split(self.tool_call_end_token)[0]
.rstrip()
)
delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip()
text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip()
# case -- we're starting a new tool call
if (
cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count > prev_tool_start_count
):
if len(delta_token_ids) > 1:
tool_call_portion = current_text.split(self.tool_call_start_token)[
-1
]
else:
tool_call_portion = None
delta = None
text_portion = None
# set cursors and state appropriately
self.current_tool_id += 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("Starting on a new tool %s", self.current_tool_id)
# case -- we're updating an existing tool call
elif (
cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count == prev_tool_start_count
):
# get the portion of the text that's the tool call
tool_call_portion = current_text.split(self.tool_call_start_token)[-1]
text_portion = None
# case -- the current tool call is being closed.
elif (
cur_tool_start_count == cur_tool_end_count
and cur_tool_end_count >= prev_tool_end_count
):
if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0:
logger.debug("attempting to close tool call, but no tool call")
return None
diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
if diff:
diff = (
diff.encode("utf-8").decode("unicode_escape")
if isinstance(diff, str)
else diff
)
if '"}' not in delta_text:
return None
end_loc = delta_text.rindex('"}')
diff = delta_text[:end_loc] + '"}'
logger.debug(
"Finishing tool and found diff that had not "
"been streamed yet: %s",
diff,
)
self.streamed_args_for_tool[self.current_tool_id] += diff
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(arguments=diff).model_dump(
exclude_none=True
),
)
]
)
# case -- otherwise we're just generating text
else:
text = delta_text.replace(self.tool_call_start_token, "")
text = text.replace(self.tool_call_end_token, "")
delta = DeltaMessage(tool_calls=[], content=text)
return delta
current_tool_call = dict()
if tool_call_portion:
current_tool_call_matches = self.stream_tool_call_portion_regex.match(
tool_call_portion
)
if current_tool_call_matches:
tool_type, tool_name, tool_args = current_tool_call_matches.groups()
current_tool_call["name"] = tool_name
current_tool_call["arguments"] = tool_args
else:
current_tool_call_name_matches = (
self.stream_tool_call_name_regex.match(tool_call_portion)
)
if current_tool_call_name_matches:
tool_type, tool_name = current_tool_call_name_matches.groups()
current_tool_call["name"] = tool_name
current_tool_call["arguments"] = ""
else:
logger.debug("Not enough token")
return None
# case - we haven't sent the tool name yet. If it's available, send
# it. otherwise, wait until it's available.
if not self.current_tool_name_sent:
if current_tool_call is None:
return None
function_name: str | None = current_tool_call.get("name")
if function_name:
self.current_tool_name_sent = True
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name
).model_dump(exclude_none=True),
)
]
)
else:
return None
# case -- otherwise, send the tool call delta
# if the tool call portion is None, send the delta as text
if tool_call_portion is None:
# if there's text but not tool calls, send that -
# otherwise None to skip chunk
delta = (
DeltaMessage(content=delta_text)
if text_portion is not None
else None
)
return delta
# now, the nitty-gritty of tool calls
# now we have the portion to parse as tool call.
logger.debug(
"Trying to parse current tool call with ID %s", self.current_tool_id
)
# if we're starting a new tool call, push an empty object in as
# a placeholder for the arguments
if len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
# main logic for tool parsing here - compare prev. partially-parsed
# JSON to the current partially-parsed JSON
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments"
)
cur_arguments = current_tool_call.get("arguments")
logger.debug("diffing old arguments: %s", prev_arguments)
logger.debug("against new ones: %s", cur_arguments)
# case -- no arguments have been created yet. skip sending a delta.
if not cur_arguments and not prev_arguments:
logger.debug("Skipping text %s - no arguments", delta_text)
delta = None
# case -- prev arguments are defined, but none are now.
# probably impossible, but not a fatal error - just keep going
elif not cur_arguments and prev_arguments:
logger.error(
"should be impossible to have arguments reset "
"mid-call. skipping streaming anything."
)
delta = None
# case -- we now have the first info about arguments available from
# autocompleting the JSON
elif cur_arguments and not prev_arguments:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=cur_arguments
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] = cur_arguments
# last case -- we have an update to existing arguments.
elif cur_arguments and prev_arguments:
if (
isinstance(delta_text, str)
and cur_arguments != prev_arguments
and len(cur_arguments) > len(prev_arguments)
and cur_arguments.startswith(prev_arguments)
):
delta_arguments = cur_arguments[len(prev_arguments) :]
logger.debug("got diff %s", delta_text)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=delta_arguments
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] = cur_arguments
else:
delta = None
# handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
else:
self.prev_tool_call_arr.append(current_tool_call)
return delta
except Exception:
logger.exception("Error trying to handle streaming tool call.")
return None # do not stream a delta. skip this token ID.
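
For contrast with the V3.1 sketch above: the V3 format carries the arguments in a fenced json block and adds a type capture, so the complete-call regex has three groups. A minimal illustrative sketch:

import re

# Same complete-call pattern as DeepSeekV3ToolParser.tool_call_regex.
pattern = re.compile(
    r"<tool▁call▁begin>(?P<type>.*)<tool▁sep>(?P<function_name>.*)\n"
    r"```json\n(?P<function_arguments>.*)\n```<tool▁call▁end>"
)
sample = (
    "<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather\n"
    '```json\n{"location": "Beijing"}\n```<tool▁call▁end><tool▁calls▁end>'
)
m = pattern.search(sample)
print(m.group("function_name"), m.group("function_arguments"))
# -> get_weather {"location": "Beijing"}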

vllm/tool_parsers/ernie45_tool_parser.py

@@ -0,0 +1,210 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
import regex as re
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
logger = init_logger(__name__)
class Ernie45ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
"""
Ernie thinking model format:
abc\n</think>\n\n\n<tool_call>\ndef\n</tool_call>\n
"""
super().__init__(tokenizer)
self.current_tool_name_sent = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id = -1
self.streamed_args_for_tool: list[str] = []
self.think_end_token = "</think>"
self.response_start_token: str = "<response>"
self.response_end_token: str = "</response>"
self.tool_call_start_token = "<tool_call>"
self.tool_call_end_token = "</tool_call>"
self.tool_calls_start_token = self.tool_call_start_token
self.newline_token: str = "<0x0A>"
self.tool_call_regex = re.compile(
r"<tool_call>\s*(?P<json>\{.*?\})\s*</tool_call>", re.DOTALL
)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
self.think_end_token_id = self.vocab.get(self.think_end_token)
self.response_start_token_id = self.vocab.get(self.response_start_token)
self.response_end_token_id = self.vocab.get(self.response_end_token)
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
self.newline_token_id = self.vocab.get(self.newline_token)
self.parser_token_ids = [
self.think_end_token_id,
self.response_start_token_id,
self.response_end_token_id,
]
self._buffer = ""
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
# sanity check; avoid unnecessary processing
if self.tool_calls_start_token not in model_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
else:
try:
tool_call_json_list = self.tool_call_regex.findall(model_output)
tool_calls = []
for tool_call_json in tool_call_json_list:
tool_call_dict = json.loads(tool_call_json)
args_str = json.dumps(
tool_call_dict.get("arguments", {}), ensure_ascii=False
)
tool_calls.append(
ToolCall(
type="function",
function=FunctionCall(
name=tool_call_dict.get("name", ""),
arguments=args_str,
),
)
)
content = model_output[
: model_output.find(self.tool_calls_start_token)
].rstrip("\n")
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if content else None,
)
except Exception:
logger.exception("Error in extracting tool call from response.")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
self._buffer += delta_text
cur_text = self._buffer
start_idx = cur_text.find(self.tool_call_start_token)
if start_idx == -1:
self._buffer = ""
# at least one tool call has already been completed
if self.current_tool_id > 0:
cur_text = ""
if self.current_tool_id == -1 and all(
token_id == self.newline_token_id for token_id in previous_token_ids
):
cur_text = cur_text.strip("\n")
# handle <response> </response> when tool_call is not triggered
# here, cur_text == delta_text
content = cur_text
if self.response_start_token_id in delta_token_ids:
content = content.lstrip("\n")
response_start_idx = content.find(self.response_start_token)
content = content[response_start_idx + len(self.response_start_token) :]
# if have </response>, remove it
response_end_idx = content.rfind(self.response_end_token)
if response_end_idx != -1:
content = content[:response_end_idx]
elif self.response_end_token_id in delta_token_ids:
response_end_idx = content.rfind(self.response_end_token)
content = content[:response_end_idx]
# remove \n after </think> or <response> or </response>
if (
len(previous_token_ids) > 0
and previous_token_ids[-1] in self.parser_token_ids
) and (
len(delta_token_ids) > 0 and delta_token_ids[0] == self.newline_token_id
):
content = content.lstrip("\n")
return DeltaMessage(content=content if content else None)
logger.debug("cur_text = %s", cur_text)
end_idx = cur_text.find(self.tool_call_end_token)
if end_idx != -1:
if self.current_tool_id == -1:
self.current_tool_id = 0
self.prev_tool_call_arr = []
self.streamed_args_for_tool = []
while len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
while len(self.streamed_args_for_tool) <= self.current_tool_id:
self.streamed_args_for_tool.append("")
extracted_tool_calls = self.extract_tool_calls(
cur_text[: end_idx + len(self.tool_call_end_token)], request
)
if len(extracted_tool_calls.tool_calls) == 0:
logger.warning("Failed to extract any tool calls.")
return None
tool_call = extracted_tool_calls.tool_calls[0]
self.prev_tool_call_arr[self.current_tool_id] = {
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
self.streamed_args_for_tool[self.current_tool_id] = (
tool_call.function.arguments
)
delta = DeltaMessage(
content=extracted_tool_calls.content,
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
id=tool_call.id,
type=tool_call.type,
function=DeltaFunctionCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
),
)
],
)
self.current_tool_id += 1
self._buffer = cur_text[end_idx + len(self.tool_call_end_token) :]
return delta
self._buffer = cur_text[start_idx:]
content = cur_text[:start_idx].rstrip("\n")
return DeltaMessage(content=content if content else None)
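
A short sketch of the Ernie45 envelope from the constructor docstring: plain JSON inside <tool_call> tags after the thinking block.

import json
import re

# Same pattern as tool_call_regex above; backtracking makes the non-greedy
# body expand past nested braces until "</tool_call>" matches.
tool_call_re = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)
sample = (
    "abc\n</think>\n\n\n<tool_call>\n"
    '{"name": "get_weather", "arguments": {"location": "北京"}}\n'
    "</tool_call>\n"
)
for blob in tool_call_re.findall(sample):
    call = json.loads(blob)
    print(call["name"], call["arguments"])  # -> get_weather {'location': '北京'}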

vllm/tool_parsers/gigachat3_tool_parser.py

@@ -0,0 +1,190 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser
logger = init_logger(__name__)
REGEX_FUNCTION_CALL = re.compile(
r"function call(?:<\|role_sep\|>\n)?(\{.*)",
re.DOTALL,
)
NAME_REGEX = re.compile(
r'"name"\s*:\s*"([^"]*)"',
re.DOTALL,
)
ARGS_REGEX = re.compile(
r'"arguments"\s*:\s*(.*)',
re.DOTALL,
)
class GigaChat3ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.tool_started: bool = False
self.tool_name_sent: bool = False
self.tool_id: str | None = None
self.prev_tool_call_arr: list[dict] = []
self.content_buffer: str = ""
self.trigger_start = "function call{"
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
match = REGEX_FUNCTION_CALL.search(model_output)
if not match:
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
content=model_output,
)
json_candidate = match.group(1).strip()
try:
data = json.loads(json_candidate)
except json.JSONDecodeError:
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
content=model_output,
)
if not (isinstance(data, dict) and "name" in data and "arguments" in data):
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
content=model_output,
)
name = data["name"]
args = data["arguments"]
if not isinstance(args, str):
args = json.dumps(args, ensure_ascii=False)
tool_calls = [
ToolCall(
type="function",
function=FunctionCall(
name=name,
arguments=args,
),
)
]
prefix = model_output[: match.start()]
content = prefix.rstrip() if prefix and prefix.strip() else None
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content,
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
func_name = None
cur_args = None
if not self.tool_started:
match = REGEX_FUNCTION_CALL.search(current_text)
if match:
self.tool_started = True
self.content_buffer = ""
else:
self.content_buffer += delta_text
clean_buffer = self.content_buffer.lstrip()
is_prefix = self.trigger_start.startswith(clean_buffer)
starts_with_trigger = clean_buffer.startswith(self.trigger_start)
if is_prefix or starts_with_trigger:
return None
else:
flush_text = self.content_buffer
self.content_buffer = ""
return DeltaMessage(content=flush_text)
match = REGEX_FUNCTION_CALL.search(current_text)
if not match:
return None
json_tail = match.group(1).strip()
name_match = NAME_REGEX.search(json_tail)
if name_match:
func_name = name_match.group(1)
args_match = ARGS_REGEX.search(json_tail)
if args_match:
cur_args = args_match.group(1).strip()
            if cur_args.endswith("}"):  # the trailing '}' may close the outer JSON
try:
candidate = cur_args[:-1].strip()
json.loads(candidate)
cur_args = candidate
except json.JSONDecodeError:
pass
if not self.prev_tool_call_arr:
self.prev_tool_call_arr.append({})
if not self.tool_name_sent:
if not func_name:
return None
self.tool_name_sent = True
self.tool_id = make_tool_call_id()
self.prev_tool_call_arr[0]["name"] = func_name
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=0,
id=self.tool_id,
type="function",
function=DeltaFunctionCall(
name=func_name,
).model_dump(exclude_none=True),
)
],
content=None,
)
if cur_args is None:
return None
prev_args = self.prev_tool_call_arr[0].get("arguments", "")
if not prev_args:
delta_args = cur_args
elif cur_args.startswith(prev_args):
delta_args = cur_args[len(prev_args) :]
else:
return None
if not delta_args:
return None
self.prev_tool_call_arr[0]["arguments"] = cur_args
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=0,
function=DeltaFunctionCall(
arguments=delta_args,
).model_dump(exclude_none=True),
)
],
content=None,
)
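
To see the trigger regex above in action, here is a hedged, self-contained example; the sample completion is fabricated, and the exact GigaChat3 template may differ between model revisions.

import json

import regex as re

REGEX_FUNCTION_CALL = re.compile(r"function call(?:<\|role_sep\|>\n)?(\{.*)", re.DOTALL)

sample = 'Let me check.function call{"name": "get_weather", "arguments": {"city": "Moscow"}}'
match = REGEX_FUNCTION_CALL.search(sample)
data = json.loads(match.group(1).strip())
print(data["name"], data["arguments"])  # get_weather {'city': 'Moscow'}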

View File

@@ -0,0 +1,200 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
from collections.abc import Sequence
from typing import Any
import regex as re
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
logger = init_logger(__name__)
class Glm4MoeModelToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.current_tool_name_sent = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id = -1
self.streamed_args_for_tool: list[str] = []
self.tool_call_start_token = "<tool_call>"
self.tool_call_end_token = "</tool_call>"
self.tool_calls_start_token = self.tool_call_start_token
self.func_call_regex = re.compile(r"<tool_call>.*?</tool_call>", re.DOTALL)
self.func_detail_regex = re.compile(
r"<tool_call>([^\n]*)\n(.*)</tool_call>", re.DOTALL
)
self.func_arg_regex = re.compile(
r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", re.DOTALL
)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
self._buffer = ""
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
def _is_string_type(
tool_name: str,
arg_name: str,
tools: list[ChatCompletionToolsParam] | None,
) -> bool:
if tools is None:
return False
for tool in tools:
if tool.function.name == tool_name:
if tool.function.parameters is None:
return False
arg_type = (
tool.function.parameters.get("properties", {})
.get(arg_name, {})
.get("type", None)
)
return arg_type == "string"
logger.debug("No tool named '%s'.", tool_name)
return False
def _deserialize(value: str) -> Any:
try:
return json.loads(value)
except Exception:
pass
try:
return ast.literal_eval(value)
except Exception:
pass
return value
matched_tool_calls = self.func_call_regex.findall(model_output)
logger.debug("model_output: %s", model_output)
try:
tool_calls = []
for match in matched_tool_calls:
tc_detail = self.func_detail_regex.search(match)
tc_name = tc_detail.group(1)
tc_args = tc_detail.group(2)
pairs = self.func_arg_regex.findall(tc_args)
arg_dct = {}
for key, value in pairs:
arg_key = key.strip()
arg_val = value.strip()
if not _is_string_type(tc_name, arg_key, request.tools):
arg_val = _deserialize(arg_val)
logger.debug("arg_key = %s, arg_val = %s", arg_key, arg_val)
arg_dct[arg_key] = arg_val
tool_calls.append(
ToolCall(
type="function",
function=FunctionCall(
name=tc_name, arguments=json.dumps(arg_dct)
),
)
)
except Exception:
logger.exception("Failed to extract tool call spec")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
else:
if len(tool_calls) > 0:
content = model_output[: model_output.find(self.tool_calls_start_token)]
return ExtractedToolCallInformation(
tools_called=True, tool_calls=tool_calls, content=content
)
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
self._buffer += delta_text
cur_text = self._buffer
start_idx = cur_text.find(self.tool_call_start_token)
if start_idx == -1:
self._buffer = ""
if self.current_tool_id > 0:
cur_text = ""
return DeltaMessage(content=cur_text)
logger.debug("cur_text = %s", cur_text)
end_idx = cur_text.find(self.tool_call_end_token)
if end_idx != -1:
if self.current_tool_id == -1:
self.current_tool_id = 0
self.prev_tool_call_arr = []
self.streamed_args_for_tool = []
while len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
while len(self.streamed_args_for_tool) <= self.current_tool_id:
self.streamed_args_for_tool.append("")
extracted_tool_calls = self.extract_tool_calls(
cur_text[: end_idx + len(self.tool_call_end_token)], request
)
if len(extracted_tool_calls.tool_calls) == 0:
logger.warning("Failed to extract any tool calls.")
return None
tool_call = extracted_tool_calls.tool_calls[0]
self.prev_tool_call_arr[self.current_tool_id] = {
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
self.streamed_args_for_tool[self.current_tool_id] = (
tool_call.function.arguments
)
delta = DeltaMessage(
content=extracted_tool_calls.content,
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
id=tool_call.id,
type=tool_call.type,
function=DeltaFunctionCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
),
)
],
)
self.current_tool_id += 1
self._buffer = cur_text[end_idx + len(self.tool_call_end_token) :]
return delta
self._buffer = cur_text[start_idx:]
return DeltaMessage(content=cur_text[:start_idx])
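
The GLM-4.5 format encodes arguments as <arg_key>/<arg_value> pairs rather than JSON, which is why the parser above needs three regexes. A small sketch of the two inner regexes on a fabricated payload:

import regex as re

detail = re.compile(r"<tool_call>([^\n]*)\n(.*)</tool_call>", re.DOTALL)
args = re.compile(r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", re.DOTALL)

sample = (
    "<tool_call>get_weather\n"
    "<arg_key>city</arg_key><arg_value>Beijing</arg_value>"
    "</tool_call>"
)
m = detail.search(sample)
name, body = m.group(1), m.group(2)
print(name, dict(args.findall(body)))  # get_weather {'city': 'Beijing'}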

View File

@@ -0,0 +1,273 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
from json import JSONDecoder
import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.tool_parsers.utils import (
consume_space,
find_common_prefix,
is_complete_json,
partial_json_loads,
)
logger = init_logger(__name__)
class Granite20bFCToolParser(ToolParser):
"""
Tool call parser for the granite-20b-functioncalling model intended
for use with the examples/tool_chat_template_granite20b_fc.jinja
template.
    Used when --enable-auto-tool-choice --tool-call-parser granite-20b-fc
    are both set
"""
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.bot_token = "<function_call>"
self.tool_start_token = self.bot_token
self.tool_call_regex = re.compile(r"<function_call>\s*")
def extract_tool_calls(
self, model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation:
if self.tool_start_token not in model_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
dec = JSONDecoder()
try:
matches = list(self.tool_call_regex.finditer(model_output))
logger.debug("Found %d tool call matches", len(matches))
raw_function_calls = []
for i, match in enumerate(matches):
# position after the <function_call> tag
start_of_json = match.end()
# end_index == the start of the next function call
# (if exists)
next_function_call_start = (
matches[i + 1].start() if i + 1 < len(matches) else None
)
raw_function_calls.append(
dec.raw_decode(
model_output[start_of_json:next_function_call_start]
)[0]
)
logger.debug("Extracted %d tool calls", len(raw_function_calls))
tool_calls = [
ToolCall(
type="function",
function=FunctionCall(
name=function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(
function_call["arguments"], ensure_ascii=False
),
),
)
for function_call in raw_function_calls
]
content = model_output[: model_output.find(self.bot_token)]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if content else None,
)
except Exception as e:
logger.error("Error in extracting tool call from response %s", e)
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
if len(current_text) < len(self.bot_token) and self.bot_token.startswith(
current_text
):
return None
if not current_text.startswith(self.bot_token):
return DeltaMessage(content=delta_text)
        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending an incomplete string, since OpenAI
        # only ever (as far as I have seen) allows sending the entire
        # tool/function name at once.
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
try:
tool_call_arr = []
is_complete = []
try:
start_idx = len(self.bot_token)
start_idx = consume_space(start_idx, current_text)
while start_idx < len(current_text):
(obj, end_idx) = partial_json_loads(current_text[start_idx:], flags)
is_complete.append(
is_complete_json(current_text[start_idx : start_idx + end_idx])
)
start_idx += end_idx
start_idx = consume_space(start_idx, current_text)
start_idx += len(self.bot_token)
start_idx = consume_space(start_idx, current_text)
tool_call_arr.append(obj)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug("not enough tokens to parse into JSON yet")
return None
            # select the current tool call: the one at the current index
current_tool_call: dict = (
tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
)
# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if len(tool_call_arr) == 0:
return None
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
elif (
len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
):
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if self.current_tool_id >= 0:
cur_arguments = current_tool_call.get("arguments")
if cur_arguments:
cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
sent = len(self.streamed_args_for_tool[self.current_tool_id])
argument_diff = cur_args_json[sent:]
logger.debug("got arguments diff: %s", argument_diff)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += (
argument_diff
)
else:
delta = None
else:
delta = None
# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("starting on new tool %d", self.current_tool_id)
return delta
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
elif not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name
).model_dump(exclude_none=True),
)
]
)
self.current_tool_name_sent = True
else:
delta = None
# now we know we're on the same tool call and we're streaming
# arguments
else:
cur_arguments = current_tool_call.get("arguments")
delta = None
if cur_arguments:
sent = len(self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments"
)
argument_diff = None
if is_complete[self.current_tool_id]:
argument_diff = cur_args_json[sent:]
elif prev_arguments:
prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
if cur_args_json != prev_args_json:
prefix = find_common_prefix(prev_args_json, cur_args_json)
argument_diff = prefix[sent:]
if argument_diff is not None:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += (
argument_diff
)
self.prev_tool_call_arr = tool_call_arr
return delta
except Exception as e:
logger.error("Error trying to handle streaming tool call: %s", e)
logger.debug(
"Skipping chunk as a result of tool streaming extraction error"
)
return None
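
The non-obvious trick above is JSONDecoder.raw_decode, which parses one JSON value and reports where it stopped, so back-to-back <function_call> payloads can be walked without counting braces. A sketch on a made-up two-call output:

from json import JSONDecoder

text = (
    '<function_call> {"name": "a", "arguments": {}}'
    '<function_call> {"name": "b", "arguments": {"x": 1}}'
)
dec = JSONDecoder()
calls = []
idx = text.find("{")
while idx != -1:
    obj, end = dec.raw_decode(text, idx)  # end is the index just past the object
    calls.append(obj)
    idx = text.find("{", end)
print([c["name"] for c in calls])  # ['a', 'b']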

View File

@@ -0,0 +1,253 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
import partial_json_parser
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.tool_parsers.utils import (
consume_space,
find_common_prefix,
is_complete_json,
partial_json_loads,
)
logger = init_logger(__name__)
class GraniteToolParser(ToolParser):
"""
Tool call parser for the granite 3.0 models. Intended
for use with the examples/tool_chat_template_granite.jinja
template.
Used when --enable-auto-tool-choice --tool-call-parser granite
    are both set
"""
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# for granite 3.0, the token `<|tool_call|>`
self.bot_token = "<|tool_call|>"
# for granite 3.1, the string `<tool_call>`
self.bot_string = "<tool_call>"
def extract_tool_calls(
self, model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation:
stripped = (
model_output.strip()
.removeprefix(self.bot_token)
.removeprefix(self.bot_string)
.lstrip()
)
if not stripped or stripped[0] != "[":
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
try:
raw_function_calls = json.loads(stripped)
if not isinstance(raw_function_calls, list):
raise Exception(
f"Expected dict or list, got {type(raw_function_calls)}"
)
logger.debug("Extracted %d tool calls", len(raw_function_calls))
tool_calls = [
ToolCall(
type="function",
function=FunctionCall(
name=function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(
function_call["arguments"], ensure_ascii=False
),
),
)
for function_call in raw_function_calls
]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=None,
)
except Exception as e:
logger.error("Error in extracting tool call from response %s", e)
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
start_idx = consume_space(0, current_text)
if current_text[start_idx:].startswith(self.bot_token):
start_idx = consume_space(start_idx + len(self.bot_token), current_text)
if current_text[start_idx:].startswith(self.bot_string):
start_idx = consume_space(start_idx + len(self.bot_string), current_text)
if (
not current_text
or start_idx >= len(current_text)
or current_text[start_idx] != "["
):
return DeltaMessage(content=delta_text)
        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending an incomplete string, since OpenAI
        # only ever (as far as I have seen) allows sending the entire
        # tool/function name at once.
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
try:
tool_call_arr = None
is_complete = None
try:
tool_calls, end_idx = partial_json_loads(
current_text[start_idx:], flags
)
if type(tool_calls) is list:
tool_call_arr = tool_calls
else:
return DeltaMessage(content=delta_text)
is_complete = [True] * len(tool_calls)
if not is_complete_json(current_text[start_idx : start_idx + end_idx]):
is_complete[-1] = False
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug("not enough tokens to parse into JSON yet")
return None
# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if not tool_call_arr:
return None
            # select the current tool call: the one at the current index
current_tool_call: dict = tool_call_arr[self.current_tool_id]
delta = None
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
if len(tool_call_arr) > self.current_tool_id + 1:
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if self.current_tool_id >= 0:
cur_arguments = current_tool_call.get("arguments")
if cur_arguments:
cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
sent = len(self.streamed_args_for_tool[self.current_tool_id])
argument_diff = cur_args_json[sent:]
logger.debug("got arguments diff: %s", argument_diff)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += (
argument_diff
)
# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("starting on new tool %d", self.current_tool_id)
return delta
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
elif not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name
).model_dump(exclude_none=True),
)
]
)
self.current_tool_name_sent = True
# now we know we're on the same tool call and we're streaming
# arguments
else:
cur_arguments = current_tool_call.get("arguments")
if cur_arguments:
sent = len(self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments"
)
argument_diff = None
if is_complete[self.current_tool_id]:
argument_diff = cur_args_json[sent:]
elif prev_arguments:
prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
if cur_args_json != prev_args_json:
prefix = find_common_prefix(prev_args_json, cur_args_json)
argument_diff = prefix[sent:]
if argument_diff is not None:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += (
argument_diff
)
self.prev_tool_call_arr = tool_call_arr
return delta
except Exception as e:
logger.error("Error trying to handle streaming tool call: %s", e)
logger.debug(
"Skipping chunk as a result of tool streaming extraction error"
)
return None
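
Both Granite parsers hinge on partial-JSON parsing with Allow flags: masking out Allow.STR keeps a half-finished string value from surfacing, which is why a truncated function name is never streamed. A hedged sketch of that behavior (as observed with partial-json-parser; exact output may vary by version):

import partial_json_parser
from partial_json_parser.core.options import Allow

chunk = '[{"name": "get_we'
# partial strings allowed: the truncated name leaks through
print(partial_json_parser.loads(chunk, Allow.ALL))               # [{'name': 'get_we'}]
# partial strings disallowed: the incomplete value is dropped
print(partial_json_parser.loads(chunk, Allow.ALL & ~Allow.STR))  # [{}]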

View File

@@ -0,0 +1,495 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
logger = init_logger(__name__)
class Hermes2ProToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
if isinstance(tokenizer, MistralTokenizer):
logger.error("Detected Mistral tokenizer when using a Hermes model")
self.model_tokenizer = tokenizer.tokenizer
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.streamed_args_for_tool: list[
str
] = [] # map what has been streamed for each tool so far to a list
self.tool_call_start_token: str = "<tool_call>"
self.tool_call_end_token: str = "</tool_call>"
self.tool_call_regex = re.compile(
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL
)
self.scratch_pad_regex = re.compile(
r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL
)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
self.tool_call_start_token_ids = self.model_tokenizer.encode(
self.tool_call_start_token, add_special_tokens=False
)
self.tool_call_end_token_ids = self.model_tokenizer.encode(
self.tool_call_end_token, add_special_tokens=False
)
self.tool_call_start_token_array = [
self.model_tokenizer.decode([token_id])
for token_id in self.tool_call_start_token_ids
]
self.tool_call_end_token_array = [
self.model_tokenizer.decode([token_id])
for token_id in self.tool_call_end_token_ids
]
self.buffered_delta_text = ""
# Very simple idea: when encountering tokens like <, tool, _call, >,
# <, /, tool, _call, >, store them in a buffer.
# When the last token is encountered, empty the buffer and return it.
# If a token appears in an incorrect sequence while storing in the buffer,
# return the preceding buffer along with the token.
def tool_call_delta_buffer(self, delta_text: str):
# If the sequence of tool_call_start or tool_call_end tokens is not yet
# complete, fill the buffer with the token and return "".
if (
delta_text in self.tool_call_start_token_array
or delta_text in self.tool_call_end_token_array
):
# If delta_text is the last token of tool_call_start_token or
# tool_call_end_token, empty the buffer and return
# the buffered text + delta_text.
if (
delta_text == self.tool_call_start_token_array[-1]
or delta_text == self.tool_call_end_token_array[-1]
):
buffered_text = self.buffered_delta_text
self.buffered_delta_text = ""
return buffered_text + delta_text
else:
self.buffered_delta_text = self.buffered_delta_text + delta_text
return ""
else:
if self.buffered_delta_text:
buffered_text = self.buffered_delta_text
self.buffered_delta_text = ""
return buffered_text + delta_text
else:
return delta_text
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
# do not skip special tokens because the tool_call tokens are
# marked "special" in some models. Since they are skipped
# prior to the call to the tool parser, it breaks tool calling.
request.skip_special_tokens = False
return request
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
# sanity check; avoid unnecessary processing
if self.tool_call_start_token not in model_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
else:
try:
                # there are two possible captures - between tags, or between
                # a tag and end-of-string - so findall returns a list of
                # tuples where one element is the function call and the
                # other is an empty string
function_call_tuples = self.tool_call_regex.findall(model_output)
# load the JSON, and then use it to build the Function and
# Tool Call
raw_function_calls = [
json.loads(match[0] if match[0] else match[1])
for match in function_call_tuples
]
tool_calls = [
ToolCall(
type="function",
function=FunctionCall(
name=function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(
function_call["arguments"], ensure_ascii=False
),
),
)
for function_call in raw_function_calls
]
content = model_output[: model_output.find(self.tool_call_start_token)]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if content else None,
)
except Exception:
logger.exception("Error in extracting tool call from response.")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
# 1. All tokens are parsed based on _text, not token_ids.
# 2. All incoming text data is processed by the tool_call_delta_buffer
# function for buffering before being used for parsing.
delta_text = self.tool_call_delta_buffer(delta_text)
# If the last characters of previous_text
# match self.buffered_delta_text, remove only the matching part.
if (
len(previous_text) >= len(self.buffered_delta_text)
and previous_text[-len(self.buffered_delta_text) :]
== self.buffered_delta_text
):
previous_text = previous_text[: -len(self.buffered_delta_text)]
current_text = previous_text + delta_text
logger.debug("delta_text: %s", delta_text)
logger.debug("delta_token_ids: %s", delta_token_ids)
        # check to see if we should be streaming a tool call - is there a
        # tool call start token in the text generated so far?
if self.tool_call_start_token not in current_text:
logger.debug("No tool call tokens found!")
return DeltaMessage(content=delta_text)
try:
# figure out where we are in the parsing by counting tool call
# start & end tags
prev_tool_start_count = previous_text.count(self.tool_call_start_token)
prev_tool_end_count = previous_text.count(self.tool_call_end_token)
cur_tool_start_count = current_text.count(self.tool_call_start_token)
cur_tool_end_count = current_text.count(self.tool_call_end_token)
tool_call_portion = None
text_portion = None
# case: if we're generating text, OR rounding out a tool call
if (
cur_tool_start_count == cur_tool_end_count
and prev_tool_end_count == cur_tool_end_count
and self.tool_call_end_token not in delta_text
):
logger.debug("Generating text content! skipping tool parsing.")
return DeltaMessage(content=delta_text)
if self.tool_call_end_token in delta_text:
logger.debug("tool_call_end_token in delta_text")
full_text = current_text + delta_text
tool_call_portion = (
full_text.split(self.tool_call_start_token)[-1]
.split(self.tool_call_end_token)[0]
.rstrip()
)
delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip()
text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip()
            # case: if tool open & close tag counts don't match, we're doing
            # something with tools in this diff (the imaginary "else" block).
            # flags for partial JSON parsing. exported constants from
            # "Allow" are handled via BIT MASK
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
# case -- we're starting a new tool call
if (
cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count > prev_tool_start_count
):
if len(delta_token_ids) > 1:
tool_call_portion = current_text.split(self.tool_call_start_token)[
-1
]
else:
tool_call_portion = None
delta = None
text_portion = None
# set cursors and state appropriately
self.current_tool_id += 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("Starting on a new tool %s", self.current_tool_id)
# case -- we're updating an existing tool call
elif (
cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count == prev_tool_start_count
):
# get the portion of the text that's the tool call
tool_call_portion = current_text.split(self.tool_call_start_token)[-1]
text_portion = None
# case -- the current tool call is being closed.
elif (
cur_tool_start_count == cur_tool_end_count
and cur_tool_end_count >= prev_tool_end_count
):
if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0:
logger.debug("attempting to close tool call, but no tool call")
return None
diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
if diff:
diff = (
diff.encode("utf-8").decode("unicode_escape")
                        if isinstance(diff, str)
else diff
)
if '"}' not in delta_text:
return None
end_loc = delta_text.rindex('"}')
diff = delta_text[:end_loc] + '"}'
logger.debug(
"Finishing tool and found diff that had not "
"been streamed yet: %s",
diff,
)
self.streamed_args_for_tool[self.current_tool_id] += diff
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(arguments=diff).model_dump(
exclude_none=True
),
)
]
)
# case -- otherwise we're just generating text
else:
text = delta_text.replace(self.tool_call_start_token, "")
text = text.replace(self.tool_call_end_token, "")
delta = DeltaMessage(tool_calls=[], content=text)
return delta
try:
current_tool_call = (
partial_json_parser.loads(tool_call_portion or "{}", flags)
if tool_call_portion
else None
)
logger.debug("Parsed tool call %s", current_tool_call)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug("not enough tokens to parse into JSON yet")
return None
except json.decoder.JSONDecodeError:
logger.debug("unable to parse JSON")
return None
# case - we haven't sent the tool name yet. If it's available, send
# it. otherwise, wait until it's available.
if not self.current_tool_name_sent:
if current_tool_call is None:
return None
function_name: str | None = current_tool_call.get("name")
if function_name:
self.current_tool_name_sent = True
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name
).model_dump(exclude_none=True),
)
]
)
else:
return None
# case -- otherwise, send the tool call delta
# if the tool call portion is None, send the delta as text
if tool_call_portion is None:
# if there's text but not tool calls, send that -
# otherwise None to skip chunk
delta = (
DeltaMessage(content=delta_text)
if text_portion is not None
else None
)
return delta
            # now, the nitty-gritty: we have the portion to parse as a tool call
logger.debug(
"Trying to parse current tool call with ID %s", self.current_tool_id
)
# if we're starting a new tool call, push an empty object in as
# a placeholder for the arguments
if len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
# main logic for tool parsing here - compare prev. partially-parsed
# JSON to the current partially-parsed JSON
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments"
)
cur_arguments = current_tool_call.get("arguments")
logger.debug("diffing old arguments: %s", prev_arguments)
logger.debug("against new ones: %s", cur_arguments)
# case -- no arguments have been created yet. skip sending a delta.
if not cur_arguments and not prev_arguments:
logger.debug("Skipping text %s - no arguments", delta_text)
delta = None
            # case -- prev arguments are defined, but none are now.
# probably impossible, but not a fatal error - just keep going
elif not cur_arguments and prev_arguments:
logger.error(
"should be impossible to have arguments reset "
"mid-call. skipping streaming anything."
)
delta = None
# case -- we now have the first info about arguments available from
# autocompleting the JSON
elif cur_arguments and not prev_arguments:
# extract the content after {"name": ..., "arguments":
# directly from tool_call_portion as cur_arguments_json,
# since cur_arguments may differ from the original text
# due to partial JSON parsing
# for example, tool_call_portion =
# {"name": "search", "arguments": {"search_request": {"
# but cur_arguments =
# {"search_request": {}}
function_name = current_tool_call.get("name")
match = re.search(
r'\{"name":\s*"'
+ re.escape(function_name)
+ r'"\s*,\s*"arguments":\s*(.*)',
tool_call_portion.strip(),
re.DOTALL,
)
if match:
cur_arguments_json = match.group(1)
else:
cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)
logger.debug("finding %s in %s", delta_text, cur_arguments_json)
# get the location where previous args differ from current.
if delta_text not in cur_arguments_json:
return None
args_delta_start_loc = cur_arguments_json.rindex(delta_text) + len(
delta_text
)
# use that to find the actual delta
arguments_delta = cur_arguments_json[:args_delta_start_loc]
logger.debug("First tokens in arguments received: %s", arguments_delta)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=arguments_delta
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += arguments_delta
# last case -- we have an update to existing arguments.
elif cur_arguments and prev_arguments:
# judge whether the tool_call_portion is a complete JSON
try:
json.loads(tool_call_portion)
is_complete_json = True
except Exception:
is_complete_json = False
# if the delta_text ends with a '}' and tool_call_portion is a
# complete JSON, then the last '}' does not belong to the
# arguments, so we should trim it off
if (
isinstance(delta_text, str)
and len(delta_text.rstrip()) >= 1
and delta_text.rstrip()[-1] == "}"
and is_complete_json
):
delta_text = delta_text.rstrip()[:-1]
logger.debug("got diff %s", delta_text)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(arguments=delta_text).model_dump(
exclude_none=True
),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += delta_text
# handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
else:
self.prev_tool_call_arr.append(current_tool_call)
return delta
except Exception:
logger.exception("Error trying to handle streaming tool call.")
return None # do not stream a delta. skip this token ID.
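
The tool_call_delta_buffer trick above matters when the tokenizer splits <tool_call> across several tokens. A self-contained sketch of the same hold-back-and-flush idea; the piece list is illustrative, not an actual tokenization.

START_PIECES = ["<tool", "_call", ">"]  # how a tokenizer might split the tag

buffer = ""
out = []
for piece in ["Hi ", "<tool", "_call", ">", '{"name": "f"}']:
    if piece in START_PIECES:
        buffer += piece
        if piece == START_PIECES[-1]:  # tag complete: flush as one unit
            out.append(buffer)
            buffer = ""
    else:
        out.append(buffer + piece)  # broken sequence: flush buffer with text
        buffer = ""
print(out)  # ['Hi ', '<tool_call>', '{"name": "f"}']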

View File

@@ -0,0 +1,420 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501, SIM102
import json
from collections.abc import Sequence
from typing import Any
import regex as re
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.tool_parsers.utils import consume_space
from vllm.utils import random_uuid
logger = init_logger(__name__)
class HunyuanA13BToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# Initialize state for streaming mode
self.prev_tool_calls: list[dict] = []
self.current_tool_id = -1
self.current_tool_name_sent = False
self.streamed_args: list[str] = [] # Track arguments sent for each tool
# For backward compatibility with tests
self.current_tools_sent: list[bool] = []
# For backward compatibility with serving code
self.prev_tool_call_arr = []
# Regex patterns for preprocessing
self.answer_tool_calls_pattern = re.compile(
r"<tool_calls>([\s\S]*?)</tool_calls>", re.DOTALL
)
self.tool_name_reg = re.compile(r'"name"\s*:\s*"([^"]+)"')
self.tool_empty_arg_reg = re.compile(
r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}'
)
        # TODO: nested JSON objects in function-call arguments are not supported.
self.tool_non_empty_arg_reg = re.compile(
r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
)
self.bot_string = "<tool_calls>"
# Define streaming state type to be initialized later
self.streaming_state: dict[str, Any] = {
"current_tool_index": -1,
"tool_ids": [],
"sent_tools": [],
}
def preprocess_model_output(
self, model_output: str
) -> tuple[str | None, str | None]:
        # find the location of each tool call
for match in self.answer_tool_calls_pattern.finditer(model_output):
start, end = match.span()
            # check whether this tool_calls block sits inside a <think> region
think_regions = [
(m.start(), m.end())
for m in re.finditer(
r"<think>(.*?)</think>", model_output, flags=re.DOTALL
)
]
in_think = any(
start > t_start and end < t_end for t_start, t_end in think_regions
)
if not in_think:
content = model_output[:start]
tool_calls_content = match.group(1).strip()
try:
json.loads(tool_calls_content)
return content, tool_calls_content
except Exception:
continue
return model_output, None
def extract_tool_calls(
self, model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation:
"""
Extract tool calls from a complete model output.
"""
try:
# Preprocess the model output
content, potential_tool_calls = self.preprocess_model_output(model_output)
if not potential_tool_calls:
                # when there is no function call, strip boilerplate text
                # that comes from the A13B chat template.
if content:
content = content.replace("助手:", "", 1)
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=content
)
# Parse the potential tool calls as JSON
tool_calls_data = json.loads(potential_tool_calls)
# Ensure it's an array
if not isinstance(tool_calls_data, list):
logger.debug("Tool calls data is not an array")
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
content=content or model_output,
)
tool_calls: list[ToolCall] = []
for idx, call in enumerate(tool_calls_data):
if (
not isinstance(call, dict)
or "name" not in call
or "arguments" not in call
):
continue
tool_call = ToolCall(
id=f"call_{random_uuid()}",
type="function",
function=FunctionCall(
name=call["name"],
arguments=(
json.dumps(call["arguments"])
if isinstance(call["arguments"], dict)
else call["arguments"]
),
),
)
tool_calls.append(tool_call)
if not content or len(content.strip()) == 0:
# clear the whitespace content.
content = None
return ExtractedToolCallInformation(
tools_called=len(tool_calls) > 0,
tool_calls=tool_calls,
content=content,
)
except Exception:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
"""
Extract tool calls for streaming mode.
"""
start_idx = consume_space(0, current_text)
if current_text[start_idx:].startswith(self.bot_string):
start_idx = consume_space(start_idx + len(self.bot_string), current_text)
if (
not current_text
or start_idx >= len(current_text)
or current_text[start_idx] != "["
):
return DeltaMessage(content=delta_text)
self._try_parse_json_tools(current_text[start_idx:])
test_delta = self._handle_test_compatibility(current_text)
if test_delta:
return test_delta
name_matches = list(self.tool_name_reg.finditer(current_text))
tool_count = len(name_matches)
if tool_count == 0:
return None
self._ensure_state_arrays(tool_count)
current_idx = self.streaming_state["current_tool_index"]
name_delta = self._handle_tool_name_streaming(
current_idx, tool_count, name_matches
)
if name_delta:
return name_delta
args_delta = self._handle_tool_args_streaming(
current_text, current_idx, tool_count
)
if args_delta:
return args_delta
return None
def _try_parse_json_tools(self, current_text: str):
try:
parsed_tools = json.loads(current_text)
if isinstance(parsed_tools, list):
self.prev_tool_call_arr = parsed_tools
except json.JSONDecodeError:
pass
def _handle_test_compatibility(self, current_text: str):
if len(self.current_tools_sent) > 0:
if (
len(self.current_tools_sent) == 1
and self.current_tools_sent[0] is False
):
name_match = self.tool_name_reg.search(current_text)
if name_match:
function_name = name_match.group(1)
tool_id = f"chatcmpl-tool-{random_uuid()}"
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=0,
type="function",
id=tool_id,
function=DeltaFunctionCall(
name=function_name
).model_dump(exclude_none=True),
)
]
)
self.current_tools_sent = [True]
self.current_tool_id = 0
self.streaming_state["current_tool_index"] = 0
if len(self.streaming_state["sent_tools"]) == 0:
self.streaming_state["sent_tools"].append(
{
"sent_name": True,
"sent_arguments_prefix": False,
"sent_arguments": "",
}
)
else:
self.streaming_state["sent_tools"][0]["sent_name"] = True
self.current_tool_name_sent = True
return delta
return None
def _ensure_state_arrays(self, tool_count: int):
while len(self.streaming_state["sent_tools"]) < tool_count:
self.streaming_state["sent_tools"].append(
{
"sent_name": False,
"sent_arguments_prefix": False,
"sent_arguments": "",
}
)
while len(self.streaming_state["tool_ids"]) < tool_count:
self.streaming_state["tool_ids"].append(None)
def _handle_tool_name_streaming(
self, current_idx: int, tool_count: int, name_matches
):
if current_idx == -1 or current_idx < tool_count - 1:
next_idx = current_idx + 1
if (
next_idx < tool_count
and not self.streaming_state["sent_tools"][next_idx]["sent_name"]
):
self.streaming_state["current_tool_index"] = next_idx
self.current_tool_id = next_idx
current_idx = next_idx
tool_name = name_matches[current_idx].group(1)
tool_id = f"call_{current_idx}_{random_uuid()}"
self.streaming_state["tool_ids"][current_idx] = tool_id
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=current_idx,
type="function",
id=tool_id,
function=DeltaFunctionCall(name=tool_name).model_dump(
exclude_none=True
),
)
]
)
self.streaming_state["sent_tools"][current_idx]["sent_name"] = True
self.current_tool_name_sent = True
while len(self.streamed_args) <= current_idx:
self.streamed_args.append("")
return delta
return None
def _handle_tool_args_streaming(
self, current_text: str, current_idx: int, tool_count: int
):
if current_idx >= 0 and current_idx < tool_count:
empty_args_match = self.tool_empty_arg_reg.search(current_text)
if empty_args_match and empty_args_match.start() > 0:
for i in range(tool_count):
if i == current_idx:
if not self.streaming_state["sent_tools"][current_idx][
"sent_arguments_prefix"
]:
self.streaming_state["sent_tools"][current_idx][
"sent_arguments_prefix"
] = True
self.streaming_state["sent_tools"][current_idx][
"sent_arguments"
] = "{}"
while len(self.streamed_args) <= current_idx:
self.streamed_args.append("")
self.streamed_args[current_idx] += "{}"
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=current_idx,
function=DeltaFunctionCall(
arguments="{}"
).model_dump(exclude_none=True),
)
]
)
if current_idx < tool_count - 1:
self.streaming_state["current_tool_index"] += 1
self.current_tool_id = self.streaming_state[
"current_tool_index"
]
return delta
args_matches = list(self.tool_non_empty_arg_reg.finditer(current_text))
if current_idx < len(args_matches):
args_text = args_matches[current_idx].group(1)
is_last_tool = current_idx == tool_count - 1
if not is_last_tool:
next_tool_pos = current_text.find(
"},{", args_matches[current_idx].start()
)
if next_tool_pos != -1:
args_end_pos = next_tool_pos + 1
args_text = (
current_text[
args_matches[current_idx].start() : args_end_pos
]
.split('"arguments":')[1]
.strip()
)
sent_args = self.streaming_state["sent_tools"][current_idx][
"sent_arguments"
]
if not self.streaming_state["sent_tools"][current_idx][
"sent_arguments_prefix"
] and args_text.startswith("{"):
self.streaming_state["sent_tools"][current_idx][
"sent_arguments_prefix"
] = True
self.streaming_state["sent_tools"][current_idx][
"sent_arguments"
] = "{"
while len(self.streamed_args) <= current_idx:
self.streamed_args.append("")
self.streamed_args[current_idx] += "{"
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=current_idx,
function=DeltaFunctionCall(arguments="{").model_dump(
exclude_none=True
),
)
]
)
return delta
if args_text.startswith(sent_args):
args_diff = args_text[len(sent_args) :]
if args_diff:
self.streaming_state["sent_tools"][current_idx][
"sent_arguments"
] = args_text
while len(self.streamed_args) <= current_idx:
self.streamed_args.append("")
self.streamed_args[current_idx] += args_diff
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=current_idx,
function=DeltaFunctionCall(
arguments=args_diff
).model_dump(exclude_none=True),
)
]
)
return delta
if args_text.endswith("}") and args_text == sent_args:
if current_idx < tool_count - 1:
self.streaming_state["current_tool_index"] += 1
self.current_tool_id = self.streaming_state[
"current_tool_index"
]
return None
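
Hunyuan emits a plain JSON array between <tool_calls> tags, so the whole block can be pulled out with one regex and parsed with a single json.loads. A sketch on a fabricated completion:

import json

import regex as re

pattern = re.compile(r"<tool_calls>([\s\S]*?)</tool_calls>", re.DOTALL)
sample = 'Sure.<tool_calls>[{"name": "get_time", "arguments": {"tz": "UTC"}}]</tool_calls>'
body = pattern.search(sample).group(1).strip()
for call in json.loads(body):
    print(call["name"], json.dumps(call["arguments"]))  # get_time {"tz": "UTC"}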

View File

@@ -0,0 +1,227 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
import partial_json_parser
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.tool_parsers.utils import extract_intermediate_diff
logger = init_logger(__name__)
class Internlm2ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.position = 0
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
            # do not skip special tokens because internlm uses the special
            # tokens to indicate the start and end of the tool call
            # information.
request.skip_special_tokens = False
return request
def get_arguments(self, obj):
if "parameters" in obj:
return obj.get("parameters")
elif "arguments" in obj:
return obj.get("arguments")
return None
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
if "<|action_start|>" not in current_text:
self.position = len(current_text)
return DeltaMessage(content=delta_text)
        # if a tool call has already been sent, return an empty delta message
# to make sure the finish_reason will be sent correctly.
if self.current_tool_id > 0:
return DeltaMessage(content="")
last_pos = self.position
if "<|action_start|><|plugin|>" not in current_text[last_pos:]:
return None
new_delta = current_text[last_pos:]
text, action = new_delta.split("<|action_start|><|plugin|>")
if len(text) > 0:
self.position = self.position + len(text)
return DeltaMessage(content=text)
action = action.strip()
        action = action.split("<|action_end|>")[0]
        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending an incomplete string, since OpenAI
        # only ever (as far as I have seen) allows sending the entire
        # tool/function name at once.
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
try:
parsable_arr = action
            # internlm2 generates tool calls as a single object;
            # it does not support parallel tool calls
try:
tool_call_arr: dict = partial_json_parser.loads(parsable_arr, flags)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug("not enough tokens to parse into JSON yet")
return None
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
if not self.current_tool_name_sent:
function_name = tool_call_arr.get("name")
if function_name:
self.current_tool_id = self.current_tool_id + 1
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name
).model_dump(exclude_none=True),
)
]
)
self.current_tool_name_sent = True
self.streamed_args_for_tool.append("")
else:
delta = None
# now we know we're on the same tool call and we're streaming
# arguments
else:
prev_arguments = self.get_arguments(
self.prev_tool_call_arr[self.current_tool_id]
)
cur_arguments = self.get_arguments(tool_call_arr)
# not arguments generated
if not cur_arguments and not prev_arguments:
delta = None
# will never happen
elif not cur_arguments and prev_arguments:
logger.error(
"INVARIANT - impossible to have arguments reset mid-arguments"
)
delta = None
# first time to get parameters
elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)
arguments_delta = cur_arguments_json[
: cur_arguments_json.index(delta_text) + len(delta_text)
]
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=arguments_delta
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += arguments_delta
# both prev and cur parameters, send the increase parameters
elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
argument_diff = extract_intermediate_diff(
cur_args_json, prev_args_json
)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += argument_diff
            # finish by saving the parsed call as the previous state for
            # the next diff, returning the delta (None in the base case)
tool_call_arr["arguments"] = self.get_arguments(tool_call_arr)
self.prev_tool_call_arr = [tool_call_arr]
return delta
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction error"
)
return None
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
text = model_output
tools = request.tools
if "<|action_start|><|plugin|>" in text:
text, action = text.split("<|action_start|><|plugin|>")
            action = action.split("<|action_end|>")[0]
action = action[action.find("{") :]
action_dict = json.loads(action)
name, parameters = (
action_dict["name"],
json.dumps(
action_dict.get("parameters", action_dict.get("arguments", {})),
ensure_ascii=False,
),
)
if not tools or name not in [t.function.name for t in tools]:
                return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=text
)
tool_calls = [
ToolCall(function=FunctionCall(name=name, arguments=parameters))
]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=text if len(text) > 0 else None,
)
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=text
)
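
For reference, the <|action_start|><|plugin|> framing handled above can be exercised standalone; the completion string below is fabricated.

import json

sample = (
    "I will look that up.<|action_start|><|plugin|>"
    '{"name": "search", "parameters": {"q": "vllm"}}<|action_end|>'
)
text, action = sample.split("<|action_start|><|plugin|>")
action = action.split("<|action_end|>")[0]
call = json.loads(action[action.find("{") :])
print(text, "->", call["name"], call["parameters"])  # I will look that up. -> search {'q': 'vllm'}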

View File

@@ -0,0 +1,323 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.utils import extract_intermediate_diff
logger = init_logger(__name__)
class JambaToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
if isinstance(self.model_tokenizer, MistralTokenizer):
raise ValueError(
"Detected a MistralTokenizer tokenizer when using a Jamba model"
)
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.streamed_args_for_tool: list[
str
] = [] # map what has been streamed for each tool so far to a list
self.tool_calls_start_token: str = "<tool_calls>"
self.tool_calls_end_token: str = "</tool_calls>"
self.tool_calls_regex = re.compile(
rf"{self.tool_calls_start_token}(.*?){self.tool_calls_end_token}", re.DOTALL
)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token)
self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token)
if (
self.tool_calls_start_token_id is None
or self.tool_calls_end_token_id is None
):
raise RuntimeError(
"Jamba Tool parser could not locate tool calls start/end "
"tokens in the tokenizer!"
)
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
            # do not skip special tokens because jamba uses the special
            # tokens to indicate the start and end of the tool call
            # information.
request.skip_special_tokens = False
return request
def extract_tool_calls(
self, model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation:
# sanity check; avoid unnecessary processing
if self.tool_calls_start_token not in model_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
else:
try:
# use a regex to find the tool call between the tags
function_calls = self.tool_calls_regex.findall(model_output)[0]
# load the JSON, and then use it to build the Function and
# Tool Call
raw_function_calls = json.loads(function_calls)
tool_calls = [
ToolCall(
type="function",
function=FunctionCall(
name=function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(
function_call["arguments"], ensure_ascii=False
),
),
)
for function_call in raw_function_calls
]
content = model_output[: model_output.find(self.tool_calls_start_token)]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if (len(content) > 0 and content != " ") else None,
)
except Exception:
logger.exception("Error in extracting tool call from response.")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
# if the tool call token is not in the tokens generated so far,
# treat the output as plain content since it's not a tool call
if self.tool_calls_start_token not in current_text:
return DeltaMessage(content=delta_text)
# if the tool call token ID IS in the tokens generated so far, that
# means we're parsing as tool calls now
# handle the case where we just detected the tool calls start
# token, which marks the beginning of tool calling
if (
self.tool_calls_start_token_id in delta_token_ids
and len(delta_token_ids) == 1
):
# if it's the only token, return None, so we don't send a chat
# completion and don't send a control token
return None
# bit-mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow an incomplete string, since the OpenAI API
# only ever sends the entire tool/function name at once.
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
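# Illustration (assumed partial_json_parser behavior): with Allow.ALL a
# half-generated value such as '{"name": "get_we' may parse with the
# incomplete string included; masking out Allow.STR withholds incomplete
# strings, so a partially generated function name is never streamed.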
try:
# Extract the tool calls between the special tool call tokens
parsable_arr = current_text.split(self.tool_calls_start_token)[-1].split(
self.tool_calls_end_token
)[0]
# tool calls are generated in an array, so do partial JSON
# parsing on the entire array
try:
tool_call_arr: list[dict] = partial_json_parser.loads(
parsable_arr, flags
)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug("not enough tokens to parse into JSON yet")
return None
# select the current tool call: the one our state cursor points at
current_tool_call: dict = (
tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
)
# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if len(tool_call_arr) == 0:
return None
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
elif (
len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
):
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if self.current_tool_id >= 0:
diff: str | None = current_tool_call.get("arguments")
if diff:
diff = json.dumps(diff, ensure_ascii=False).replace(
self.streamed_args_for_tool[self.current_tool_id], ""
)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=diff
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += diff
else:
delta = None
else:
delta = None
# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("starting on new tool %d", self.current_tool_id)
return delta
# case: update an existing tool - this is handled below
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
if not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name
).model_dump(exclude_none=True),
)
]
)
self.current_tool_name_sent = True
else:
delta = None
# now we know we're on the same tool call and we're streaming
# arguments
else:
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments"
)
cur_arguments = current_tool_call.get("arguments")
new_text = delta_text.replace("'", '"')
if not cur_arguments and not prev_arguments:
delta = None
elif not cur_arguments and prev_arguments:
logger.error(
"INVARIANT - impossible to have arguments reset mid-arguments"
)
delta = None
elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)
logger.debug("finding %s in %s", new_text, cur_arguments_json)
arguments_delta = cur_arguments_json[
: cur_arguments_json.index(new_text) + len(new_text)
]
logger.debug(
"First tokens in arguments received: %s", arguments_delta
)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=arguments_delta
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += arguments_delta
elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
logger.debug(
"Searching for diff between \n%s\n%s",
cur_args_json,
prev_args_json,
)
argument_diff = extract_intermediate_diff(
cur_args_json, prev_args_json
)
logger.debug("got arguments diff: %s", argument_diff)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += argument_diff
else:
# all argument-state combinations are handled above, so this
# branch is effectively unreachable; nothing to stream
delta = None
# save the current state for diffing on the next iteration and
# return the computed delta (None is the base case)
self.prev_tool_call_arr = tool_call_arr
return delta
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction error"
)
return None
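The streaming entry point above keeps state across chunks; a hedged sketch of a driver loop that could feed it follows. The chunks iterable, parser, request, and emit() hook are illustrative assumptions, not vLLM APIs.
# Hypothetical driver for extract_tool_calls_streaming (illustrative only).
previous_text: str = ""
previous_ids: list[int] = []
for delta_text, delta_ids in chunks:  # chunks: assumed (str, list[int]) pairs
    current_text = previous_text + delta_text
    current_ids = previous_ids + list(delta_ids)
    delta = parser.extract_tool_calls_streaming(
        previous_text, current_text, delta_text,
        previous_ids, current_ids, delta_ids, request,
    )
    if delta is not None:
        emit(delta)  # emit: assumed transport hook
    previous_text, previous_ids = current_text, current_ids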

View File

@@ -0,0 +1,590 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# code modified from deepseekv3_tool_parser.py
from collections.abc import Sequence
import regex as re
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
logger = init_logger(__name__)
class KimiK2ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
# arguments streamed so far for each tool, indexed by tool
self.streamed_args_for_tool: list[str] = []
# Section-level state management to prevent token leakage
self.in_tool_section: bool = False
self.token_buffer: str = ""
# Buffer size: empirical worst-case for longest marker (~30 chars) * 2
# + safety margin for unicode + partial overlap. Prevents unbounded growth.
self.buffer_max_size: int = 1024
self.section_char_count: int = 0 # Track characters processed in tool section
self.max_section_chars: int = 8192 # Force exit if section exceeds this
self._buffer_overflow_logged: bool = False # Log overflow once per session
# Support both singular and plural variants
self.tool_calls_start_token: str = "<|tool_calls_section_begin|>"
self.tool_calls_end_token: str = "<|tool_calls_section_end|>"
self.tool_calls_start_token_variants: list[str] = [
"<|tool_calls_section_begin|>",
"<|tool_call_section_begin|>", # singular variant
]
self.tool_calls_end_token_variants: list[str] = [
"<|tool_calls_section_end|>",
"<|tool_call_section_end|>", # singular variant
]
self.tool_call_start_token: str = "<|tool_call_begin|>"
self.tool_call_end_token: str = "<|tool_call_end|>"
self.tool_call_regex = re.compile(
r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[^<]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>(?:(?!<\|tool_call_begin\|>).)*?)\s*<\|tool_call_end\|>",
re.DOTALL,
)
self.stream_tool_call_portion_regex = re.compile(
r"(?P<tool_call_id>.+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>.*)"
)
self.stream_tool_call_name_regex = re.compile(r"(?P<tool_call_id>.+:\d+)\s*")
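# Illustrative wire format these patterns target (example values assumed):
# <|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>
# {"city": "Paris"}<|tool_call_end|>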
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token)
self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token)
# Get token IDs for all variants
self.tool_calls_start_token_ids: list[int] = [
tid
for variant in self.tool_calls_start_token_variants
if (tid := self.vocab.get(variant)) is not None
]
self.tool_calls_end_token_ids: list[int] = [
tid
for variant in self.tool_calls_end_token_variants
if (tid := self.vocab.get(variant)) is not None
]
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if (
self.tool_calls_start_token_id is None
or self.tool_calls_end_token_id is None
):
raise RuntimeError(
"Kimi-K2 Tool parser could not locate tool call start/end "
"tokens in the tokenizer!"
)
def _check_and_strip_markers(self, text: str) -> tuple[str, bool, bool]:
"""
Check for section begin/end markers in text and strip them.
Returns: (cleaned_text, found_section_begin, found_section_end)
"""
found_begin = False
found_end = False
cleaned = text
# Check for section begin markers (any variant)
for variant in self.tool_calls_start_token_variants:
if variant in cleaned:
cleaned = cleaned.replace(variant, "")
found_begin = True
# Check for section end markers (any variant)
for variant in self.tool_calls_end_token_variants:
if variant in cleaned:
cleaned = cleaned.replace(variant, "")
found_end = True
return cleaned, found_begin, found_end
def _reset_section_state(self) -> None:
"""Reset state when exiting tool section."""
self.in_tool_section = False
self.token_buffer = ""
self.section_char_count = 0
def reset_streaming_state(self) -> None:
"""
Reset all streaming state. Call this between requests to prevent
state leakage when parser instance is reused.
"""
# Reset section state
self._reset_section_state()
# Reset parent class state
self.current_tool_name_sent = False
self.prev_tool_call_arr = []
self.current_tool_id = -1
self.streamed_args_for_tool = []
logger.debug("Streaming state reset")
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
# sanity check; avoid unnecessary processing
if self.tool_calls_start_token not in model_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
else:
try:
# the regex captures two named groups, so findall returns a
# list of (tool_call_id, function_arguments) tuples
function_call_tuples = self.tool_call_regex.findall(model_output)
logger.debug("function_call_tuples: %s", function_call_tuples)
tool_calls = []
for match in function_call_tuples:
function_id, function_args = match
# function_id: functions.get_weather:0 or get_weather:0
function_name = function_id.split(":")[0].split(".")[-1]
tool_calls.append(
ToolCall(
id=function_id,
type="function",
function=FunctionCall(
name=function_name, arguments=function_args
),
)
)
content = model_output[: model_output.find(self.tool_calls_start_token)]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if content else None,
)
except Exception:
logger.exception("Error in extracting tool call from response.")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
logger.debug("delta_text: %s", delta_text)
logger.debug("delta_token_ids: %s", delta_token_ids)
# Flag to defer section exit until after tool parsing completes
deferred_section_exit = False
# Add delta to buffer for split marker detection
self.token_buffer += delta_text
# Enforce buffer size limit to prevent memory issues
if len(self.token_buffer) > self.buffer_max_size:
if not self._buffer_overflow_logged:
logger.warning(
"Token buffer exceeded max size (%d bytes), flushing excess. "
"This may indicate very long markers or unusual tokenization.",
self.buffer_max_size,
)
self._buffer_overflow_logged = True
# Keep only the most recent content that might contain partial markers
self.token_buffer = self.token_buffer[-self.buffer_max_size // 2 :]
# Check buffer for section markers (handles split tokens)
buffered_text, found_section_begin, found_section_end = (
self._check_and_strip_markers(self.token_buffer)
)
# Track section state transitions
if found_section_begin and not self.in_tool_section:
logger.debug("Entering tool section")
self.in_tool_section = True
self.token_buffer = buffered_text # Use cleaned buffer
self.section_char_count = 0 # Reset counter for new section
if found_section_end and self.in_tool_section:
logger.debug("Detected section end marker")
# CRITICAL: Don't exit early if tool_call_end is in this chunk.
# Tool parser must emit final arguments/close first to avoid dropping
# the final tool update and leaking tokens into reasoning channel.
has_tool_end = self.tool_call_end_token_id in delta_token_ids
if has_tool_end:
# Defer exit until after tool parsing completes
deferred_section_exit = True
logger.debug("Deferring section exit: tool_call_end in same chunk")
self.token_buffer = buffered_text
else:
# No tool call ending, safe to exit immediately
logger.debug("Exiting tool section")
remaining = buffered_text
self._reset_section_state()
# Return remaining text as reasoning content if non-empty
if remaining.strip():
return DeltaMessage(content=remaining)
# Return empty delta to maintain function contract
# (always returns DeltaMessage)
return DeltaMessage(content="")
else:
self.token_buffer = buffered_text
# Check if any variant of section start token is in current_token_ids
has_section_token = any(
tid in current_token_ids for tid in self.tool_calls_start_token_ids
)
# Early return: if no section token detected yet, return as reasoning content
if not has_section_token and not self.in_tool_section:
logger.debug("No tool call tokens found!")
# Don't clear buffer - it needs to accumulate partial markers across deltas
# Buffer overflow is already protected by the max-size check above
return DeltaMessage(content=delta_text)
# Strip section markers from delta_text for subsequent processing
# NOTE: This preprocessing happens BEFORE the regex-based tool call
# parsing (from PR #24847) to ensure markers are removed cleanly
# before pattern matching. No double-stripping occurs because
# section markers and tool call markers are distinct.
delta_text, _, _ = self._check_and_strip_markers(delta_text)
# Error recovery: If in tool section for too long, force exit
if self.in_tool_section:
self.section_char_count += len(delta_text)
if self.section_char_count > self.max_section_chars:
logger.warning(
"Tool section exceeded max length (%d chars), forcing exit. "
"This may indicate malformed model output.",
self.max_section_chars,
)
self._reset_section_state()
# Deferred exit already handled by forced exit above
# Return remaining content as reasoning (or empty delta if no content)
return DeltaMessage(content=delta_text if delta_text.strip() else "")
try:
# figure out where we are in the parsing by counting tool call
# start & end tags
prev_tool_start_count = previous_token_ids.count(
self.tool_call_start_token_id
)
prev_tool_end_count = previous_token_ids.count(self.tool_call_end_token_id)
cur_tool_start_count = current_token_ids.count(
self.tool_call_start_token_id
)
cur_tool_end_count = current_token_ids.count(self.tool_call_end_token_id)
tool_call_portion = None
text_portion = None
# case: if we're generating text, OR rounding out a tool call
if (
cur_tool_start_count == cur_tool_end_count
and prev_tool_end_count == cur_tool_end_count
and self.tool_call_end_token not in delta_text
):
# CRITICAL FIX: Suppress content if in tool section but
# no tool calls started
if self.in_tool_section and cur_tool_start_count == 0:
logger.debug(
"In tool section but no tool calls started yet. "
"Suppressing: %s",
delta_text,
)
# Return empty delta to maintain iterator contract
return DeltaMessage(content="")
logger.debug("Generating text content! skipping tool parsing.")
return DeltaMessage(content=delta_text)
if self.tool_call_end_token in delta_text:
logger.debug("tool_call_end_token in delta_text")
full_text = current_text + delta_text
tool_call_portion = (
full_text.split(self.tool_call_start_token)[-1]
.split(self.tool_call_end_token)[0]
.rstrip()
)
delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip()
text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip()
# case -- we're starting a new tool call
if (
cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count > prev_tool_start_count
):
if len(delta_token_ids) > 1:
tool_call_portion = current_text.split(self.tool_call_start_token)[
-1
]
else:
tool_call_portion = None
delta = None
text_portion = None
# set cursors and state appropriately
self.current_tool_id += 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("Starting on a new tool %s", self.current_tool_id)
# case -- we're updating an existing tool call
elif (
cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count == prev_tool_start_count
):
# get the portion of the text that's the tool call
tool_call_portion = current_text.split(self.tool_call_start_token)[-1]
text_portion = None
# case -- the current tool call is being closed.
elif (
cur_tool_start_count == cur_tool_end_count
and cur_tool_end_count >= prev_tool_end_count
):
if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0:
logger.debug("attempting to close tool call, but no tool call")
# Handle deferred section exit before returning
if deferred_section_exit and self.in_tool_section:
self._reset_section_state()
return None
diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
if diff:
diff = (
diff.encode("utf-8").decode("unicode_escape")
if isinstance(diff, str)
else diff
)
if '"}' not in delta_text:
# Handle deferred section exit before returning
if deferred_section_exit and self.in_tool_section:
self._reset_section_state()
return None
end_loc = delta_text.rindex('"}')
diff = delta_text[:end_loc] + '"}'
logger.debug(
"Finishing tool and found diff that had not "
"been streamed yet: %s",
diff,
)
self.streamed_args_for_tool[self.current_tool_id] += diff
# Handle deferred section exit before returning
if deferred_section_exit and self.in_tool_section:
logger.debug("Completing deferred section exit")
self._reset_section_state()
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(arguments=diff).model_dump(
exclude_none=True
),
)
]
)
# case -- otherwise we're just generating text
else:
# Check if we're in tool section - if so, suppress
if self.in_tool_section:
logger.debug("In tool section, suppressing text generation")
# Handle deferred section exit before returning
if deferred_section_exit:
self._reset_section_state()
return DeltaMessage(content="")
text = delta_text.replace(self.tool_call_start_token, "")
text = text.replace(self.tool_call_end_token, "")
delta = DeltaMessage(tool_calls=[], content=text)
# Handle deferred section exit before returning
if deferred_section_exit and self.in_tool_section:
self._reset_section_state()
return delta
current_tool_call = dict()
if tool_call_portion:
current_tool_call_matches = self.stream_tool_call_portion_regex.match(
tool_call_portion
)
if current_tool_call_matches:
tool_id, tool_args = current_tool_call_matches.groups()
tool_name = tool_id.split(":")[0].split(".")[-1]
current_tool_call["id"] = tool_id
current_tool_call["name"] = tool_name
current_tool_call["arguments"] = tool_args
else:
current_tool_call_name_matches = (
self.stream_tool_call_name_regex.match(tool_call_portion)
)
if current_tool_call_name_matches:
(tool_id_str,) = current_tool_call_name_matches.groups()
tool_name = tool_id_str.split(":")[0].split(".")[-1]
current_tool_call["id"] = tool_id_str
current_tool_call["name"] = tool_name
current_tool_call["arguments"] = ""
else:
logger.debug("Not enough token")
return None
# case - we haven't sent the tool name yet. If it's available, send
# it. otherwise, wait until it's available.
if not self.current_tool_name_sent:
if current_tool_call is None:
return None
function_name: str | None = current_tool_call.get("name")
tool_id = current_tool_call.get("id")
if function_name:
self.current_tool_name_sent = True
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=tool_id,
function=DeltaFunctionCall(
name=function_name
).model_dump(exclude_none=True),
)
]
)
else:
return None
# case -- otherwise, send the tool call delta
# if the tool call portion is None, send the delta as text
if tool_call_portion is None:
# if there's text but not tool calls, send that -
# otherwise None to skip chunk
delta = (
DeltaMessage(content=delta_text)
if text_portion is not None
else None
)
return delta
# now for the nitty-gritty: we have the portion to parse
# as a tool call
logger.debug(
"Trying to parse current tool call with ID %s", self.current_tool_id
)
# if we're starting a new tool call, push an empty object in as
# a placeholder for the arguments
if len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
# main logic for tool parsing here - compare prev. partially-parsed
# JSON to the current partially-parsed JSON
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments"
)
cur_arguments = current_tool_call.get("arguments")
logger.debug("diffing old arguments: %s", prev_arguments)
logger.debug("against new ones: %s", cur_arguments)
# case -- no arguments have been created yet. skip sending a delta.
if not cur_arguments and not prev_arguments:
logger.debug("Skipping text %s - no arguments", delta_text)
delta = None
# case -- prev arguments are defined, but none are now.
# probably impossible, but not a fatal error - just keep going
elif not cur_arguments and prev_arguments:
logger.error(
"should be impossible to have arguments reset "
"mid-call. skipping streaming anything."
)
delta = None
# case -- we now have the first info about arguments available from
# autocompleting the JSON
elif cur_arguments and not prev_arguments:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=cur_arguments
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] = cur_arguments
# last case -- we have an update to existing arguments.
elif cur_arguments and prev_arguments:
if (
isinstance(delta_text, str)
and cur_arguments != prev_arguments
and len(cur_arguments) > len(prev_arguments)
and cur_arguments.startswith(prev_arguments)
):
delta_arguments = cur_arguments[len(prev_arguments) :]
logger.debug("got diff %s", delta_text)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=delta_arguments
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] = cur_arguments
else:
delta = None
# handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
else:
self.prev_tool_call_arr.append(current_tool_call)
# Handle deferred section exit after tool parsing completes
if deferred_section_exit and self.in_tool_section:
logger.debug("Completing deferred section exit")
self._reset_section_state()
return delta
except Exception:
logger.exception("Error trying to handle streaming tool call.")
return None # do not stream a delta. skip this token ID.
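Because a section marker can be split across deltas, the parser above accumulates text in token_buffer before stripping markers. A minimal standalone illustration of that idea, assuming a two-chunk split of the begin marker:
# Split-marker accumulation, sketched standalone.
MARKERS = ("<|tool_calls_section_begin|>", "<|tool_call_section_begin|>")
buffer = ""
for chunk in ["<|tool_calls_sec", "tion_begin|>rest of the stream"]:
    buffer += chunk
    for marker in MARKERS:
        if marker in buffer:
            buffer = buffer.replace(marker, "")
            print("entered tool section; carried text:", buffer)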

View File

@@ -0,0 +1,341 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
from collections.abc import Sequence
from typing import Any
import regex as re
from transformers import PreTrainedTokenizerBase
import vllm.envs as envs
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
logger = init_logger(__name__)
class _UnexpectedAstError(Exception):
pass
class Llama4PythonicToolParser(ToolParser):
"""
Tool call parser for Llama 4 models that produce tool calls in a pythonic style.
Use --enable-auto-tool-choice --tool-call-parser llama4_pythonic
"""
# TODO(mdepinet): Possible future improvements:
# 1. Support text + tools separated by either <|python_tag|> or \n\n
# 2. Support tools outside of a list (or separated by a semicolon).
# This depends on item 1 for consistent streaming.
# Neither of these are necessary for e.g. ToolACE, but both would help make
# Llama3.2 models more reliable.
TOOL_CALL_REGEX = re.compile(
r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
re.DOTALL,
)
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
# Rename for readability. This is NOT a tool id.
@property
def current_tool_index(self) -> int:
return self.current_tool_id
@current_tool_index.setter
def current_tool_index(self, value: int) -> None:
self.current_tool_id = value
def extract_tool_calls(
self, model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
"""
# remove <|python_start|> and <|python_end|>
# as Llama 4 model sometime will output those tokens
if model_output.startswith("<|python_start|>"):
model_output = model_output[len("<|python_start|>") :]
model_output = model_output.replace("<|python_end|>", "")
is_tool_call_pattern = False
try:
is_tool_call_pattern = (
self.TOOL_CALL_REGEX.match(
model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
)
is not None
)
except TimeoutError:
logger.warning("Regex timeout occurred when matching tool call pattern.")
logger.debug(
"Regex timeout occurred when matching user input: %s", model_output
)
if not is_tool_call_pattern:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
try:
module = ast.parse(model_output)
parsed = getattr(module.body[0], "value", None)
if isinstance(parsed, ast.List) and all(
isinstance(e, ast.Call) for e in parsed.elts
):
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=[
_handle_single_tool(e) # type: ignore
for e in parsed.elts
],
content=None,
)
else:
raise _UnexpectedAstError(
"Tool output must be a list of function calls"
)
except Exception:
logger.exception("Error in extracting tool call from response.")
# Treat as regular text
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
if not current_text.startswith("[") and not current_text.startswith(
"<|python_start|>"
):
return DeltaMessage(content=delta_text)
try:
# remove <|python_start|> and <|python_end|>
if current_text.startswith("<|python_start|>"):
current_text = current_text[len("<|python_start|>") :]
if current_text.endswith("<|python_end|>"):
current_text = current_text[: current_text.rfind("<|python_end|>")]
valid_and_added_text = _make_valid_python(current_text)
if valid_and_added_text is None:
return None
valid_text, added_text = valid_and_added_text
module = ast.parse(valid_text)
parsed = getattr(module.body[0], "value", None)
if not isinstance(parsed, ast.List) or not all(
isinstance(e, ast.Call) for e in parsed.elts
):
raise _UnexpectedAstError(
"Tool output must be a list of function calls"
)
tool_calls = [
_handle_single_tool(e) # type: ignore
for e in parsed.elts
]
tool_deltas = []
for index, new_call in enumerate(tool_calls):
if index < self.current_tool_index:
continue
self.current_tool_index = index
if len(self.streamed_args_for_tool) == index:
self.streamed_args_for_tool.append("")
new_call_complete = (
index < len(tool_calls) - 1 or ")]" not in added_text
)
if new_call_complete:
self.current_tool_index += 1
withheld_suffix = added_text[:-2] if not new_call_complete else ""
if not new_call_complete and added_text[-2] == ")":
# Function call is incomplete. Withhold the closing bracket.
withheld_suffix = withheld_suffix + "}"
# Strings get single quotes in the model-produced string.
# JSON requires double quotes.
withheld_suffix = withheld_suffix.replace("'", '"')
delta = _compute_tool_delta(
self.streamed_args_for_tool[index], new_call, index, withheld_suffix
)
if delta is not None:
tool_deltas.append(delta)
if (
delta.function is not None
and delta.function.arguments is not None
):
self.streamed_args_for_tool[index] += delta.function.arguments
# HACK: serving_chat.py inspects the internal state of tool parsers
# when determining its final streaming delta, automatically
# adding autocompleted JSON.
# These two lines avoid that nonsense while ensuring finish_reason
# is set to tool_calls when at least one tool is called.
if tool_deltas and not self.prev_tool_call_arr:
self.prev_tool_call_arr = [{"arguments": {}}]
if tool_deltas:
return DeltaMessage(tool_calls=tool_deltas)
elif not added_text and self.current_tool_id > 0:
# Return an empty DeltaMessage once the tool calls are all done
# so that finish_reason gets set.
return DeltaMessage(content="")
else:
return None
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction error"
)
return None
def _get_parameter_value(val: ast.expr) -> Any:
if isinstance(val, ast.Constant):
return val.value
elif isinstance(val, ast.Dict):
if not all(isinstance(k, ast.Constant) for k in val.keys):
raise _UnexpectedAstError("Dict tool call arguments must have literal keys")
return {
k.value: _get_parameter_value(v) # type: ignore
for k, v in zip(val.keys, val.values)
}
elif isinstance(val, ast.List):
return [_get_parameter_value(v) for v in val.elts]
else:
raise _UnexpectedAstError("Tool call arguments must be literals")
def _handle_single_tool(call: ast.Call) -> ToolCall:
if not isinstance(call.func, ast.Name):
raise _UnexpectedAstError("Invalid tool call name")
function_name = call.func.id
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = _get_parameter_value(keyword.value)
return ToolCall(
type="function",
function=FunctionCall(name=function_name, arguments=json.dumps(arguments)),
)
def _make_valid_python(text: str) -> tuple[str, str] | None:
bracket_stack = []
for index, char in enumerate(text):
if char in {"[", "(", "{"}:
bracket_stack.append(char)
elif char == "]":
if not bracket_stack or bracket_stack.pop() != "[":
raise _UnexpectedAstError("Mismatched square brackets")
elif char == ")":
if not bracket_stack or bracket_stack.pop() != "(":
raise _UnexpectedAstError("Mismatched parentheses")
elif char == "}":
if not bracket_stack or bracket_stack.pop() != "{":
raise _UnexpectedAstError("Mismatched curly braces")
elif char in {"'", '"'}:
if bracket_stack and bracket_stack[-1] == char:
if index > 0 and text[index - 1] == "\\":
# Treat an escaped quote as a regular character
pass
else:
bracket_stack.pop()
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
# Double quote within a single quote string or vice versa.
pass
else:
bracket_stack.append(char)
text = text.rstrip()
if text.endswith("=") or text.endswith(":"):
# Since we have no type information for this property/parameter value,
# we can't fill in a valid value.
return None
if bracket_stack and bracket_stack[-1] == "{":
trailing_dict_text = text[: text.rfind("{")]
num_keys = trailing_dict_text.count(":")
num_values = trailing_dict_text.count(",")
if num_keys <= num_values:
return None # Incomplete property name within parameter value
if bracket_stack and bracket_stack[-1] == "(":
trailing_params_text = text[: text.rfind("(")]
num_full_param_names = trailing_params_text.count("=")
num_full_param_values = trailing_params_text.count(",")
if num_full_param_names <= num_full_param_values:
return None # Incomplete parameter name
if text.endswith(","):
text = text[:-1]
if (
bracket_stack
and bracket_stack[-1] == "["
and not text.endswith("[")
and not text.endswith(")")
):
return None # Incomplete function name
added_text = ""
for char in reversed(bracket_stack):
if char == "[":
added_text += "]"
elif char == "(":
added_text += ")"
elif char == "{":
added_text += "}"
elif char == "'":
added_text += "'"
elif char == '"':
added_text += '"'
return text + added_text, added_text
def _compute_tool_delta(
previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str
) -> DeltaToolCall | None:
new_call_args = new_call.function.arguments
if withheld_suffix:
assert new_call_args.endswith(withheld_suffix)
new_call_args = new_call_args[: -len(withheld_suffix)]
if not previously_sent_args:
return DeltaToolCall(
id=new_call.id,
type="function",
index=index,
function=DeltaFunctionCall(
name=new_call.function.name,
arguments=new_call_args,
),
)
arg_diff = new_call_args[len(previously_sent_args) :]
return (
DeltaToolCall(
id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff)
)
if arg_diff
else None
)
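_make_valid_python above closes any still-open brackets so a partial pythonic call can be parsed with ast. A simplified standalone sketch of the bracket-closing idea; the quote handling and incompleteness checks of the real function are deliberately omitted:
import ast

PAIRS = {"[": "]", "(": ")", "{": "}"}

def close_brackets(text: str) -> str:
    # Naive version: assumes well-nested input and ignores quotes.
    stack = []
    for ch in text:
        if ch in PAIRS:
            stack.append(ch)
        elif ch in PAIRS.values():
            stack.pop()
    return text + "".join(PAIRS[ch] for ch in reversed(stack))

partial = "[get_weather(city='Paris'"
print(close_brackets(partial))      # [get_weather(city='Paris')]
ast.parse(close_brackets(partial))  # parses as a list containing one call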

View File

@@ -0,0 +1,324 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase
import vllm.envs as envs
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.tool_parsers.utils import (
find_common_prefix,
is_complete_json,
partial_json_loads,
)
logger = init_logger(__name__)
class Llama3JsonToolParser(ToolParser):
"""
Tool call parser for Llama 3.x and 4 models intended for use with the
examples/tool_chat_template_llama.jinja template.
Used when --enable-auto-tool-choice --tool-call-parser llama3_json or
llama4_json are set.
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
# initialize properties used for state when parsing tool calls in
# streaming mode
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
# arguments streamed so far for each tool, indexed by tool
self.streamed_args_for_tool: list[str] = []
self.bot_token = "<|python_tag|>"
self.bot_token_id = tokenizer.encode(self.bot_token, add_special_tokens=False)[
0
]
# Simple regex to find opening braces - we'll use JSON decoder for parsing
# This handles arbitrary nesting depth correctly
self.tool_call_start_regex = re.compile(r"\{")
self.json_decoder = json.JSONDecoder()
def extract_tool_calls(
self, model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
Only extracts JSON content and ignores any surrounding plain text.
Supports both single JSON and multiple JSONs separated by semicolons.
"""
# Quick check before running regex
if not (self.bot_token in model_output or "{" in model_output):
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
# Keep track of the end index of the last parsed JSON object
# so we don't parse inner brackets
end_index = -1
tool_calls: list[ToolCall] = []
try:
for match in self.tool_call_start_regex.finditer(
model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
):
start_index = match.start()
# Skip if this brace is inside a previously parsed JSON object
if start_index <= end_index:
continue
try:
obj, json_end_index = self.json_decoder.raw_decode(
model_output[start_index:]
)
end_index = start_index + json_end_index
# raise KeyError if missing
name = obj["name"]
arguments_or_params = (
obj["arguments"] if "arguments" in obj else obj["parameters"]
)
tool_calls.append(
ToolCall(
type="function",
function=FunctionCall(
name=name,
# function call args are JSON but as a string
arguments=json.dumps(
arguments_or_params, ensure_ascii=False
),
),
)
)
except KeyError as e:
# Missing required key
missing_key = str(e).strip("'\"")
logger.exception(
"Couldn't extract tool call from JSON response. "
"Required key '%s' not present. "
"Returning output in content with empty tool calls.",
missing_key,
)
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
except Exception:
# Any other error during parsing
logger.exception(
"Error in extracting tool call from response. "
"Returning output in content with empty tool calls"
)
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
except TimeoutError:
logger.warning("Regex timeout occurred when matching tool call pattern.")
logger.debug(
"Regex timeout occurred when matching user input: %s", model_output
)
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
# If we have valid tool calls, return them normally
if tool_calls:
return ExtractedToolCallInformation(
tools_called=True, tool_calls=tool_calls, content=None
)
# No valid tool calls found
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
if not (
current_text.startswith(self.bot_token) or current_text.startswith("{")
):
return DeltaMessage(content=delta_text)
# bit-mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow an incomplete string, since the OpenAI API
# only ever sends the entire tool/function name at once.
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
try:
tool_call_arr = []
is_complete = []
try:
# depending on the prompt format the Llama model may or may not
# prefix the output with the <|python_tag|> token
start_idx = (
len(self.bot_token)
if current_text.startswith(self.bot_token)
else 0
)
while start_idx < len(current_text):
(obj, end_idx) = partial_json_loads(current_text[start_idx:], flags)
is_complete.append(
is_complete_json(current_text[start_idx : start_idx + end_idx])
)
start_idx += end_idx + len("; ")
# depending on the prompt Llama can use
# either arguments or parameters
if "parameters" in obj:
assert "arguments" not in obj, (
"model generated both parameters and arguments"
)
obj["arguments"] = obj["parameters"]
tool_call_arr.append(obj)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug("not enough tokens to parse into JSON yet")
return None
# select the current tool call: the one our state cursor points at
current_tool_call: dict = (
tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
)
# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if len(tool_call_arr) == 0:
return None
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
elif (
len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
):
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if self.current_tool_id >= 0:
cur_arguments = current_tool_call.get("arguments")
if cur_arguments:
cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
sent = len(self.streamed_args_for_tool[self.current_tool_id])
argument_diff = cur_args_json[sent:]
logger.debug("got arguments diff: %s", argument_diff)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += (
argument_diff
)
else:
delta = None
else:
delta = None
# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("starting on new tool %d", self.current_tool_id)
return delta
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
elif not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name
).model_dump(exclude_none=True),
)
]
)
self.current_tool_name_sent = True
else:
delta = None
# now we know we're on the same tool call and we're streaming
# arguments
else:
cur_arguments = current_tool_call.get("arguments")
delta = None
if cur_arguments:
sent = len(self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments"
)
argument_diff = None
if is_complete[self.current_tool_id]:
argument_diff = cur_args_json[sent:]
elif prev_arguments:
prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
if cur_args_json != prev_args_json:
prefix = find_common_prefix(prev_args_json, cur_args_json)
argument_diff = prefix[sent:]
if argument_diff is not None:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += (
argument_diff
)
self.prev_tool_call_arr = tool_call_arr
return delta
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction error"
)
return None
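The brace scan combined with json.JSONDecoder.raw_decode above extracts JSON objects embedded in plain text without a nesting-aware regex. A standalone illustration follows; the sample model output is made up, and the standard re module stands in for the regex module used above:
import json
import re

decoder = json.JSONDecoder()
text = 'Sure. {"name": "get_weather", "arguments": {"city": "Paris"}} done'
end = -1
for match in re.finditer(r"\{", text):
    if match.start() <= end:
        continue  # brace inside an already-decoded object
    try:
        obj, consumed = decoder.raw_decode(text[match.start():])
    except json.JSONDecodeError:
        continue
    end = match.start() + consumed
    print(obj["name"], obj["arguments"])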

View File

@@ -0,0 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import regex as re
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
class LongcatFlashToolParser(Hermes2ProToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.tool_call_start_token: str = "<longcat_tool_call>"
self.tool_call_end_token: str = "</longcat_tool_call>"
self.tool_call_regex = re.compile(
r"<longcat_tool_call>(.*?)</longcat_tool_call>|<longcat_tool_call>(.*)",
re.DOTALL,
)
self.tool_call_start_token_ids = self.model_tokenizer.encode(
self.tool_call_start_token, add_special_tokens=False
)
self.tool_call_end_token_ids = self.model_tokenizer.encode(
self.tool_call_end_token, add_special_tokens=False
)
self.tool_call_start_token_array = [
self.model_tokenizer.decode([token_id])
for token_id in self.tool_call_start_token_ids
]
self.tool_call_end_token_array = [
self.model_tokenizer.decode([token_id])
for token_id in self.tool_call_end_token_ids
]
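Precomputing per-token pieces of the markers lets a streaming loop recognize a marker that arrives split across several tokens. A hedged sketch of such a check; the helper below is hypothetical and the actual token split depends on the tokenizer:
# Hypothetical check: could the tail of the emitted text begin a marker?
MARKER = "<longcat_tool_call>"

def tail_may_start_marker(text: str) -> bool:
    limit = min(len(text), len(MARKER))
    return any(MARKER.startswith(text[-i:]) for i in range(1, limit + 1))

print(tail_may_start_marker("some text <longcat"))  # True: withhold the tail
print(tail_may_start_marker("plain text"))          # False: safe to emit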

View File

@@ -0,0 +1,643 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import uuid
from collections.abc import Sequence
from typing import Any
import regex as re
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
logger = init_logger(__name__)
class MinimaxM2ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.prev_tool_call_arr: list[dict] = []
# Sentinel tokens
self.tool_call_start_token: str = "<minimax:tool_call>"
self.tool_call_end_token: str = "</minimax:tool_call>"
self.invoke_start_prefix: str = "<invoke name="
self.invoke_end_token: str = "</invoke>"
self.parameter_prefix: str = "<parameter name="
self.parameter_end_token: str = "</parameter>"
# Streaming state variables
self.current_tool_name_sent: bool = False
# Override base class type - we use string IDs for tool calls
self.current_tool_id: str | None = None # type: ignore
self.streamed_args_for_tool: list[str] = []
self.is_tool_call_started: bool = False
self.failed_count: int = 0
# Initialize streaming state variables
self.current_tool_index: int = 0
self.invoke_index: int = 0
self.header_sent: bool = False
self.current_function_name: str | None = None
self.current_param_name: str | None = None
self.current_param_value: str = ""
self.param_count: int = 0
self.in_param: bool = False
self.in_function: bool = False
self.accumulated_text: str = ""
self.json_started: bool = False
self.json_closed: bool = False
self.accumulated_params: dict = {}
self.streaming_request: ChatCompletionRequest | None = None
# Enhanced streaming state - reset for each new message
self._reset_streaming_state()
# Regex patterns for complete parsing
self.tool_call_complete_regex = re.compile(
r"<minimax:tool_call>(.*?)</minimax:tool_call>", re.DOTALL
)
self.invoke_complete_regex = re.compile(
r"<invoke name=(.*?)</invoke>", re.DOTALL
)
self.parameter_complete_regex = re.compile(
r"<parameter name=(.*?)</parameter>", re.DOTALL
)
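# Illustrative wire format these patterns target (example values assumed):
# <minimax:tool_call><invoke name="get_weather">
# <parameter name="city">Paris</parameter></invoke></minimax:tool_call>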
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
raise RuntimeError(
"MiniMax M2 Tool parser could not locate tool call start/end "
"tokens in the tokenizer!"
)
logger.debug(
"vLLM Successfully import tool parser %s !", self.__class__.__name__
)
def _generate_tool_call_id(self) -> str:
"""Generate a unique tool call ID."""
return f"call_{uuid.uuid4().hex[:24]}"
def _reset_streaming_state(self):
"""Reset all streaming state."""
self.current_tool_index = 0
self.invoke_index = 0
self.is_tool_call_started = False
self.header_sent = False
self.current_tool_id = None
self.current_function_name = None
self.current_param_name = None
self.current_param_value = ""
self.param_count = 0
self.in_param = False
self.in_function = False
self.accumulated_text = ""
self.json_started = False
self.json_closed = False
# Store accumulated parameters for type conversion
self.accumulated_params = {}
self.streaming_request = None
# Clear previous tool call history to avoid state pollution
self.prev_tool_call_arr.clear()
def _extract_name(self, name_str: str) -> str:
"""Extract name from quoted string."""
name_str = name_str.strip()
if (
name_str.startswith('"')
and name_str.endswith('"')
or name_str.startswith("'")
and name_str.endswith("'")
):
return name_str[1:-1]
return name_str
def _convert_param_value(self, value: str, param_type: str) -> Any:
"""Convert parameter value to the correct type."""
if value.lower() == "null":
return None
param_type = param_type.lower()
if param_type in ["string", "str", "text"]:
return value
elif param_type in ["integer", "int"]:
try:
return int(value)
except (ValueError, TypeError):
return value
elif param_type in ["number", "float"]:
try:
val = float(value)
return val if val != int(val) else int(val)
except (ValueError, TypeError):
return value
elif param_type in ["boolean", "bool"]:
return value.lower() in ["true", "1"]
elif param_type in ["object", "array"]:
try:
return json.loads(value)
except json.JSONDecodeError:
return value
else:
# Try JSON parse first, fallback to string
try:
return json.loads(value)
except json.JSONDecodeError:
return value
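# Illustrative conversions (derived from the branches above):
#   _convert_param_value("3", "integer")       -> 3
#   _convert_param_value("true", "boolean")    -> True
#   _convert_param_value('{"a": 1}', "object") -> {"a": 1}
#   _convert_param_value("null", "string")     -> None
# Unknown types try json.loads first and fall back to the raw string.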
def _parse_single_invoke(
self, invoke_str: str, tools: list | None
) -> ToolCall | None:
"""Parse a single <invoke> block."""
# Extract function name
name_match = re.search(r"^([^>]+)", invoke_str)
if not name_match:
return None
function_name = self._extract_name(name_match.group(1))
# Get parameter configuration
param_config = {}
if tools:
for tool in tools:
if (
hasattr(tool, "function")
and tool.function.name == function_name
and hasattr(tool.function, "parameters")
):
params = tool.function.parameters
if isinstance(params, dict) and "properties" in params:
param_config = params["properties"]
break
# Extract parameters
param_dict = {}
for match in self.parameter_complete_regex.findall(invoke_str):
param_match = re.search(r"^([^>]+)>(.*)", match, re.DOTALL)
if param_match:
param_name = self._extract_name(param_match.group(1))
param_value = param_match.group(2).strip()
if param_value.startswith("\n"):
param_value = param_value[1:]
if param_value.endswith("\n"):
param_value = param_value[:-1]
# Get parameter type
param_type = "string"
if (
param_name in param_config
and isinstance(param_config[param_name], dict)
and "type" in param_config[param_name]
):
param_type = param_config[param_name]["type"]
# Convert value
param_dict[param_name] = self._convert_param_value(
param_value, param_type
)
return ToolCall(
type="function",
function=FunctionCall(
name=function_name,
arguments=json.dumps(param_dict, ensure_ascii=False),
),
)
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
"""Extract tool calls from complete model output (non-streaming)."""
# Quick check
if self.tool_call_start_token not in model_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
try:
tool_calls = []
# Find all complete tool_call blocks
for tool_call_match in self.tool_call_complete_regex.findall(model_output):
# Find all invokes within this tool_call
for invoke_match in self.invoke_complete_regex.findall(tool_call_match):
tool_call = self._parse_single_invoke(
invoke_match, request.tools if request else None
)
if tool_call:
tool_calls.append(tool_call)
if not tool_calls:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
# Update prev_tool_call_arr
self.prev_tool_call_arr.clear()
for tool_call in tool_calls:
self.prev_tool_call_arr.append(
{
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
}
)
# Extract content before first tool call
first_tool_idx = model_output.find(self.tool_call_start_token)
content = model_output[:first_tool_idx] if first_tool_idx > 0 else None
return ExtractedToolCallInformation(
tools_called=True, tool_calls=tool_calls, content=content
)
except Exception:
logger.exception("Error extracting tool calls")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int], # pylint: disable=unused-argument
current_token_ids: Sequence[int], # pylint: disable=unused-argument
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
"""Extract tool calls from streaming model output."""
# Store request for type conversion
if not previous_text or self.tool_call_start_token in delta_text:
self._reset_streaming_state()
self.streaming_request = request
# If no delta text, return None unless it's an EOS token after tools
if not delta_text:
# Check if this is an EOS token after all tool calls are complete
if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids:
# Count complete tool calls
complete_calls = len(
self.tool_call_complete_regex.findall(current_text)
)
# If we have completed tool calls and populated prev_tool_call_arr
if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
# Check if all tool calls are closed
open_calls = current_text.count(
self.tool_call_start_token
) - current_text.count(self.tool_call_end_token)
if open_calls == 0:
# Return empty delta for finish_reason processing
return DeltaMessage(content="")
elif not self.is_tool_call_started and current_text:
# This is a regular content response that's now complete
return DeltaMessage(content="")
return None
# Update accumulated text
self.accumulated_text = current_text
# Check if we need to advance to next tool
if self.json_closed and not self.in_function:
# Check if this tool call has ended
invoke_ends = current_text.count(self.invoke_end_token)
if invoke_ends > self.current_tool_index:
# This tool has ended, advance to next
self.current_tool_index += 1
self.header_sent = False
self.param_count = 0
self.json_started = False
self.json_closed = False
self.in_function = False # Now we can safely set this to False
self.accumulated_params = {}
# Continue processing next tool
return None
# Handle normal content before tool calls
if not self.is_tool_call_started:
# Check if tool call is starting
if (
self.tool_call_start_token_id in delta_token_ids
or self.tool_call_start_token in delta_text
):
self.is_tool_call_started = True
# Return any content before the tool call
if self.tool_call_start_token in delta_text:
content_before = delta_text[
: delta_text.index(self.tool_call_start_token)
]
if content_before:
return DeltaMessage(content=content_before)
return None
else:
# Check if we're between tool calls - skip whitespace
if (
current_text.rstrip().endswith(self.tool_call_end_token)
and delta_text.strip() == ""
):
# We just ended a tool call, skip whitespace
return None
# Normal content, no tool call
return DeltaMessage(content=delta_text)
# Check if we're between tool calls (waiting for next one)
invoke_starts_count = current_text.count(self.invoke_start_prefix)
if self.current_tool_index >= invoke_starts_count:
# We're past all tool calls, shouldn't be here
return None
# Find the current tool call portion
invoke_start_positions: list[int] = []
idx = 0
while True:
idx = current_text.find(self.invoke_start_prefix, idx)
if idx == -1:
break
invoke_start_positions.append(idx)
idx += len(self.invoke_start_prefix)
if self.current_tool_index >= len(invoke_start_positions):
# No more tool calls to process yet
return None
invoke_start_idx = invoke_start_positions[self.current_tool_index]
# Find where this tool call ends (or current position if not ended yet)
invoke_end_idx = current_text.find(self.invoke_end_token, invoke_start_idx)
if invoke_end_idx == -1:
tool_text = current_text[invoke_start_idx:]
else:
tool_text = current_text[
invoke_start_idx : invoke_end_idx + len(self.invoke_end_token)
]
# Looking for function header
if not self.header_sent:
if self.invoke_start_prefix in tool_text:
func_start = tool_text.find(self.invoke_start_prefix) + len(
self.invoke_start_prefix
)
# Find the end quote for the function name
func_end = tool_text.find(">", func_start)
if func_end != -1:
# Found complete function name
function_name_raw = tool_text[func_start:func_end]
self.current_function_name = self._extract_name(function_name_raw)
self.current_tool_id = self._generate_tool_call_id()
self.header_sent = True
self.in_function = True
# Add to prev_tool_call_arr immediately when we detect a tool call
# Each tool call should be recorded regardless of function name
# Ensure we don't add the same tool call index multiple times
if len(self.prev_tool_call_arr) <= self.current_tool_index:
self.prev_tool_call_arr.append(
{
"name": self.current_function_name,
"arguments": "{}", # Placeholder, will be updated later
}
)
# Send header with function info
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
id=self.current_tool_id,
function=DeltaFunctionCall(
name=self.current_function_name, arguments=""
),
type="function",
)
]
)
return None
# We've sent header, now handle function body
if self.in_function:
# Send opening brace if not sent yet
if not self.json_started:
self.json_started = True
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="{"),
)
]
)
# Make sure json_started is set if we're processing parameters
if not self.json_started:
self.json_started = True
# Check for function end in accumulated text
if not self.json_closed and self.invoke_end_token in tool_text:
# Count total parameters in the tool text
total_param_count = tool_text.count(self.parameter_prefix)
# Only close JSON if all parameters have been processed
if self.param_count >= total_param_count:
# Close JSON
self.json_closed = True
# Extract complete tool call
# Find the invoke content
invoke_start = tool_text.find(self.invoke_start_prefix) + len(
self.invoke_start_prefix
)
invoke_content_end = tool_text.find(
self.invoke_end_token, invoke_start
)
if invoke_content_end != -1:
invoke_content = tool_text[invoke_start:invoke_content_end]
# Parse to get the complete arguments
try:
parsed_tool = self._parse_single_invoke(
invoke_content,
self.streaming_request.tools
if self.streaming_request
else None,
)
if parsed_tool and self.current_tool_index < len(
self.prev_tool_call_arr
):
# Update existing entry in prev_tool_call_arr
args = parsed_tool.function.arguments
self.prev_tool_call_arr[self.current_tool_index][
"arguments"
] = args
except Exception:
pass # Ignore parsing errors during streaming
result = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="}"),
)
]
)
# Reset state for next tool
self.json_closed = True
self.in_function = False
self.accumulated_params = {}
logger.debug("[M2_STREAMING] Tool call completed")
return result
else:
# Don't close JSON yet, continue processing parameters
return None
# Look for parameters
# Find all parameter starts
param_starts = []
idx = 0
while True:
idx = tool_text.find(self.parameter_prefix, idx)
if idx == -1:
break
param_starts.append(idx)
idx += len(self.parameter_prefix)
# Check if we should start a new parameter
if not self.in_param and self.param_count < len(param_starts):
# Process the next parameter
param_idx = param_starts[self.param_count]
param_start = param_idx + len(self.parameter_prefix)
remaining = tool_text[param_start:]
if ">" in remaining:
# We have the complete parameter name
name_end = remaining.find(">")
param_name_raw = remaining[:name_end]
self.current_param_name = self._extract_name(param_name_raw)
# Find the parameter value
value_start = param_start + name_end + 1
value_text = tool_text[value_start:]
if value_text.startswith("\n"):
value_text = value_text[1:]
# Find where this parameter ends
param_end_idx = value_text.find(self.parameter_end_token)
if param_end_idx == -1:
# No closing tag, look for next parameter or function end
next_param_idx = value_text.find(self.parameter_prefix)
func_end_idx = value_text.find(self.invoke_end_token)
if next_param_idx != -1 and (
func_end_idx == -1 or next_param_idx < func_end_idx
):
param_end_idx = next_param_idx
elif func_end_idx != -1:
param_end_idx = func_end_idx
else:
# Neither found, check if tool call is complete
if self.invoke_end_token in tool_text:
# Tool call and parameter is complete
param_end_idx = len(value_text)
else:
# Still streaming, wait for more content
return None
if param_end_idx != -1:
# Complete parameter found
param_value = value_text[:param_end_idx]
if param_value.endswith("\n"):
param_value = param_value[:-1]
# Store raw value for later processing
self.accumulated_params[self.current_param_name] = param_value
# Get parameter configuration for type conversion
param_config = {}
if self.streaming_request and self.streaming_request.tools:
for tool in self.streaming_request.tools:
if (
hasattr(tool, "function")
and tool.function.name == self.current_function_name
and hasattr(tool.function, "parameters")
):
params = tool.function.parameters
if (
isinstance(params, dict)
and "properties" in params
):
param_config = params["properties"]
break
# Get parameter type
param_type = "string"
if (
self.current_param_name in param_config
and isinstance(param_config[self.current_param_name], dict)
and "type" in param_config[self.current_param_name]
):
param_type = param_config[self.current_param_name]["type"]
# Convert param value to appropriate type
converted_value = self._convert_param_value(
param_value, param_type
)
# Build JSON fragment based on the converted type
# Use json.dumps to properly serialize the value
serialized_value = json.dumps(
converted_value, ensure_ascii=False
)
if self.param_count == 0:
json_fragment = (
f'"{self.current_param_name}": {serialized_value}'
)
else:
json_fragment = (
f', "{self.current_param_name}": {serialized_value}'
)
self.param_count += 1
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments=json_fragment),
)
]
)
return None
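# Illustrative summary (not from the source; the exact tag strings are
# class attributes defined elsewhere): for each invoke block, the streaming
# path above emits, in order,
#   1. a DeltaToolCall header carrying the function name and a fresh id,
#   2. an arguments fragment '{',
#   3. one fragment per parameter, e.g. '"city": "Paris"' (values are
#      type-converted against the tool schema and serialized with
#      json.dumps),
#   4. a closing fragment '}' once every parameter and the invoke end
#      token have been seen,
# so the concatenated fragments form valid JSON arguments.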

View File

@@ -0,0 +1,849 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
from typing import Any
import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.tool_parsers.utils import extract_intermediate_diff
logger = init_logger(__name__)
class MinimaxToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# Initialize streaming state for tracking tool call progress
self.streaming_state: dict[str, Any] = {
"current_tool_index": -1, # Index of current tool being processed
"tool_ids": [], # List of tool call IDs
"sent_tools": [], # List of tools that have been sent
}
# Define tool call tokens and patterns
self.tool_call_start_token = "<tool_calls>"
self.tool_call_end_token = "</tool_calls>"
self.tool_call_regex = re.compile(
r"<tool_calls>(.*?)</tool_calls>|<tool_calls>(.*)", re.DOTALL
)
self.thinking_tag_pattern = r"<think>(.*?)</think>"
self.tool_name_pattern = re.compile(r'"name":\s*"([^"]+)"')
self.tool_args_pattern = re.compile(r'"arguments":\s*')
# Buffer for handling partial tool calls during streaming
self.pending_buffer = ""
self.in_thinking_tag = False
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
# Get token IDs for tool call start/end tokens
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
logger.warning(
"Minimax Tool parser could not locate tool call start/end "
"tokens in the tokenizer. Falling back to string matching."
)
def preprocess_model_output(self, model_output: str) -> str:
"""
Preprocess model output by removing tool calls from thinking tags.
Args:
model_output: Raw model output string
Returns:
Preprocessed model output with tool calls removed from thinking tags
"""
def remove_tool_calls_from_think(match):
think_content = match.group(1)
cleaned_content = re.sub(
r"<tool_calls>.*?</tool_calls>", "", think_content, flags=re.DOTALL
)
return f"<think>{cleaned_content}</think>"
return re.sub(
self.thinking_tag_pattern,
remove_tool_calls_from_think,
model_output,
flags=re.DOTALL,
)
def _clean_duplicate_braces(self, args_text: str) -> str:
"""
Clean duplicate closing braces from arguments text.
Args:
args_text: Raw arguments text
Returns:
Cleaned arguments text with proper JSON formatting
"""
args_text = args_text.strip()
if not args_text:
return args_text
try:
json.loads(args_text)
return args_text
except json.JSONDecodeError:
pass
while args_text.endswith("}}"):
candidate = args_text[:-1]
try:
json.loads(candidate)
return candidate
except json.JSONDecodeError:
args_text = candidate
return args_text
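# Example (illustrative): _clean_duplicate_braces('{"city": "Paris"}}')
# strips trailing braces until the text parses as JSON and returns
# '{"city": "Paris"}'; input that is already valid JSON is returned
# unchanged.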
def _clean_delta_braces(self, delta_text: str) -> str:
"""
Clean delta text by removing excessive closing braces.
Args:
delta_text: Delta text to clean
Returns:
Cleaned delta text
"""
if not delta_text:
return delta_text
delta_stripped = delta_text.strip()
if delta_stripped and all(c in "}\n\r\t " for c in delta_stripped):
brace_count = delta_stripped.count("}")
if brace_count > 1:
return "}\n" if delta_text.endswith("\n") else "}"
return delta_text
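# Example (illustrative): a delta made up solely of closing braces is
# collapsed to a single one, e.g. _clean_delta_braces('}}}\n') -> '}\n',
# while mixed content such as '"}' passes through unchanged.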
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
"""
Extract tool calls from model output for non-streaming mode.
Args:
model_output: Complete model output
request: Chat completion request
Returns:
ExtractedToolCallInformation containing tool calls and content
"""
processed_output = self.preprocess_model_output(model_output)
if self.tool_call_start_token not in processed_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
try:
function_call_tuples = self.tool_call_regex.findall(processed_output)
raw_function_calls = []
for match in function_call_tuples:
tool_call_content = match[0] if match[0] else match[1]
if tool_call_content.strip():
lines = tool_call_content.strip().split("\n")
for line in lines:
line = line.strip()
if line and line.startswith("{") and line.endswith("}"):
try:
parsed_call = json.loads(line)
raw_function_calls.append(parsed_call)
except json.JSONDecodeError:
continue
tool_calls = []
for function_call in raw_function_calls:
if "name" in function_call and "arguments" in function_call:
tool_calls.append(
ToolCall(
type="function",
function=FunctionCall(
name=function_call["name"],
arguments=json.dumps(
function_call["arguments"], ensure_ascii=False
),
),
)
)
processed_pos = processed_output.find(self.tool_call_start_token)
if processed_pos != -1:
processed_content = processed_output[:processed_pos].strip()
if processed_content:
lines = processed_content.split("\n")
for line in reversed(lines):
line = line.strip()
if line:
pos = model_output.find(line)
if pos != -1:
content = model_output[: pos + len(line)]
break
else:
content = ""
else:
content = ""
else:
content = model_output
return ExtractedToolCallInformation(
tools_called=len(tool_calls) > 0,
tool_calls=tool_calls,
content=content.strip() if content.strip() else None,
)
except Exception:
logger.exception(
"An unexpected error occurred during tool call extraction."
)
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def _update_thinking_state(self, text: str) -> None:
"""
Update the thinking tag state based on text content.
Args:
text: Text to analyze for thinking tags
"""
open_count = text.count("<think>")
close_count = text.count("</think>")
self.in_thinking_tag = open_count > close_count or (
open_count == close_count and text.endswith("</think>")
)
def _is_potential_tag_start(self, text: str) -> bool:
"""
Check if text might be the start of a tool call tag.
Args:
text: Text to check
Returns:
True if text could be the start of a tool call tag
"""
for tag in [self.tool_call_start_token, self.tool_call_end_token]:
if any(
tag.startswith(text[-i:])
for i in range(1, min(len(text) + 1, len(tag)))
):
return True
return False
def _should_buffer_content(self, delta_text: str) -> bool:
"""
Determine if content should be buffered for later processing.
Args:
delta_text: Delta text to check
Returns:
True if content should be buffered
"""
if self.in_thinking_tag:
return False
return bool(
self.pending_buffer
or self.tool_call_start_token in delta_text
or self.tool_call_end_token in delta_text
or delta_text.startswith("<")
)
def _split_content_for_buffering(self, delta_text: str) -> tuple[str, str]:
"""
Split delta text into safe content and potential tag content.
Args:
delta_text: Delta text to split
Returns:
Tuple of (safe_content, potential_tag_content)
"""
if self.in_thinking_tag:
return delta_text, ""
for tag in [self.tool_call_start_token, self.tool_call_end_token]:
for i in range(1, len(tag)):
tag_prefix = tag[:i]
pos = delta_text.rfind(tag_prefix)
if pos != -1 and tag.startswith(delta_text[pos:]):
return delta_text[:pos], delta_text[pos:]
return delta_text, ""
def _process_buffer(self, new_content: str) -> str:
"""
Process buffered content and return output content.
Args:
new_content: New content to add to buffer
Returns:
Processed output content
"""
self.pending_buffer += new_content
output_content = ""
if self.in_thinking_tag:
output_content = self.pending_buffer
self.pending_buffer = ""
return output_content
while self.pending_buffer:
start_pos = self.pending_buffer.find(self.tool_call_start_token)
end_pos = self.pending_buffer.find(self.tool_call_end_token)
if start_pos != -1 and (end_pos == -1 or start_pos < end_pos):
tag_pos, tag_len = start_pos, len(self.tool_call_start_token)
elif end_pos != -1:
tag_pos, tag_len = end_pos, len(self.tool_call_end_token)
else:
if self._is_potential_tag_start(self.pending_buffer):
break
output_content += self.pending_buffer
self.pending_buffer = ""
break
output_content += self.pending_buffer[:tag_pos]
self.pending_buffer = self.pending_buffer[tag_pos + tag_len :]
return output_content
def _reset_streaming_state(self) -> None:
"""Reset the streaming state to initial values."""
self.streaming_state = {
"current_tool_index": -1,
"tool_ids": [],
"sent_tools": [],
}
def _advance_to_next_tool(self) -> None:
"""Advance to the next tool in the streaming sequence."""
self.streaming_state["current_tool_index"] = (
int(self.streaming_state["current_tool_index"]) + 1
)
def _set_current_tool_index(self, index: int) -> None:
"""
Set the current tool index.
Args:
index: Tool index to set
"""
self.streaming_state["current_tool_index"] = index
def _get_current_tool_index(self) -> int:
"""
Get the current tool index.
Returns:
Current tool index
"""
return int(self.streaming_state["current_tool_index"])
def _get_next_unsent_tool_index(self, tool_count: int) -> int:
"""
Get the index of the next unsent tool.
Args:
tool_count: Total number of tools
Returns:
Index of next unsent tool, or -1 if all tools sent
"""
sent_tools = list(self.streaming_state["sent_tools"])
for i in range(tool_count):
if i < len(sent_tools):
if not sent_tools[i]["sent_name"]:
return i
else:
return i
return -1
def _ensure_state_arrays(self, tool_count: int) -> None:
"""
Ensure state arrays have sufficient capacity for tool_count tools.
Args:
tool_count: Number of tools to prepare for
"""
sent_tools = list(self.streaming_state["sent_tools"])
tool_ids = list(self.streaming_state["tool_ids"])
while len(sent_tools) < tool_count:
sent_tools.append(
{
"sent_name": False,
"sent_arguments": "",
"id": make_tool_call_id(),
}
)
while len(tool_ids) < tool_count:
tool_ids.append(None)
self.streaming_state["sent_tools"] = sent_tools
self.streaming_state["tool_ids"] = tool_ids
def _detect_tools_in_text(self, text: str) -> int:
"""
Detect the number of tools in text by counting name patterns.
Args:
text: Text to analyze
Returns:
Number of tools detected
"""
matches = self.tool_name_pattern.findall(text)
return len(matches)
def _find_tool_boundaries(self, text: str) -> list[tuple[int, int]]:
"""
Find the boundaries of tool calls in text.
Args:
text: Text to analyze
Returns:
List of (start, end) positions for tool calls
"""
boundaries = []
i = 0
while i < len(text):
if text[i] == "{":
start = i
depth = 0
has_name = False
has_arguments = False
while i < len(text):
if text[i] == "{":
depth += 1
elif text[i] == "}":
depth -= 1
if depth == 0:
end = i + 1
segment = text[start:end]
if '"name"' in segment and '"arguments"' in segment:
boundaries.append((start, end))
break
if not has_name and '"name"' in text[start : i + 1]:
has_name = True
if not has_arguments and '"arguments"' in text[start : i + 1]:
has_arguments = True
i += 1
if depth > 0 and has_name:
boundaries.append((start, i))
else:
i += 1
return boundaries
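# Example (illustrative): for the text
#   '{"name": "a", "arguments": {"x": 1}} {"name": "b", "arguments": {}}'
# _find_tool_boundaries returns the (start, end) spans of both top-level
# JSON objects, since each balanced object contains a "name" and an
# "arguments" key.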
def _extract_tool_args(self, tool_content: str, args_match: re.Match[str]) -> str:
"""
Extract tool arguments from tool content.
Args:
tool_content: Tool call content
args_match: Regex match for arguments pattern
Returns:
Extracted arguments as string
"""
args_start_pos = args_match.end()
remaining_content = tool_content[args_start_pos:]
if remaining_content.strip().startswith("{"):
depth = 0
for i, char in enumerate(remaining_content):
if char == "{":
depth += 1
elif char == "}":
depth -= 1
if depth == 0:
return remaining_content[: i + 1]
else:
args_end = remaining_content.find("}")
if args_end > 0:
return remaining_content[:args_end].strip()
return remaining_content.rstrip("}").strip()
def _get_current_tool_content(
self, text: str, tool_index: int
) -> tuple[str | None, str | None]:
"""
Get the content of a specific tool by index.
Args:
text: Text containing tool calls
tool_index: Index of tool to extract
Returns:
Tuple of (tool_name, tool_arguments) or (None, None) if not found
"""
boundaries = self._find_tool_boundaries(text)
if tool_index >= len(boundaries):
return None, None
start, end = boundaries[tool_index]
tool_content = text[start:end]
name_match = self.tool_name_pattern.search(tool_content)
name = name_match.group(1) if name_match else None
args_match = self.tool_args_pattern.search(tool_content)
if args_match:
try:
args_text = self._extract_tool_args(tool_content, args_match)
return name, args_text
except Exception:
remaining_content = tool_content[args_match.end() :]
args_text = remaining_content.rstrip("}").strip()
return name, args_text
return name, None
def _handle_tool_name_streaming(
self, tool_content: str, tool_count: int
) -> DeltaMessage | None:
"""
Handle streaming of tool names.
Args:
tool_content: Content containing tool calls
tool_count: Total number of tools
Returns:
DeltaMessage with tool name or None if no tool to stream
"""
next_idx = self._get_next_unsent_tool_index(tool_count)
if next_idx == -1:
return None
boundaries = self._find_tool_boundaries(tool_content)
if next_idx >= len(boundaries):
return None
tool_name, _ = self._get_current_tool_content(tool_content, next_idx)
if not tool_name:
return None
self._set_current_tool_index(next_idx)
sent_tools = list(self.streaming_state["sent_tools"])
tool_ids = list(self.streaming_state["tool_ids"])
tool_id = sent_tools[next_idx]["id"]
tool_ids[next_idx] = tool_id
sent_tools[next_idx]["sent_name"] = True
self.streaming_state["sent_tools"] = sent_tools
self.streaming_state["tool_ids"] = tool_ids
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=next_idx,
type="function",
id=tool_id,
function=DeltaFunctionCall(name=tool_name).model_dump(
exclude_none=True
),
)
]
)
def _handle_tool_args_streaming(
self, tool_content: str, tool_count: int
) -> DeltaMessage | None:
"""
Handle streaming of tool arguments.
Args:
tool_content: Content containing tool calls
tool_count: Total number of tools
Returns:
DeltaMessage with tool arguments or None if no arguments to stream
"""
current_idx = self._get_current_tool_index()
if current_idx < 0 or current_idx >= tool_count:
return None
tool_name, tool_args = self._get_current_tool_content(tool_content, current_idx)
if not tool_name or tool_args is None:
return None
sent_tools = list(self.streaming_state["sent_tools"])
if not sent_tools[current_idx]["sent_name"]:
return None
clean_args = self._clean_duplicate_braces(tool_args)
sent_args = sent_tools[current_idx]["sent_arguments"]
if clean_args != sent_args:
if sent_args and clean_args.startswith(sent_args):
args_delta = extract_intermediate_diff(clean_args, sent_args)
if args_delta:
args_delta = self._clean_delta_braces(args_delta)
sent_tools[current_idx]["sent_arguments"] = clean_args
self.streaming_state["sent_tools"] = sent_tools
if clean_args.endswith("}"):
self._advance_to_next_tool()
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=current_idx,
function=DeltaFunctionCall(
arguments=args_delta
).model_dump(exclude_none=True),
)
]
)
elif not sent_args and clean_args:
clean_args_delta = self._clean_delta_braces(clean_args)
sent_tools[current_idx]["sent_arguments"] = clean_args
self.streaming_state["sent_tools"] = sent_tools
if clean_args.endswith("}"):
self._advance_to_next_tool()
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=current_idx,
function=DeltaFunctionCall(
arguments=clean_args_delta
).model_dump(exclude_none=True),
)
]
)
return None
def _is_end_tool_calls(self, current_text: str) -> bool:
if self.tool_call_end_token not in current_text:
return False
end_token_positions = []
search_start = 0
while True:
pos = current_text.find(self.tool_call_end_token, search_start)
if pos == -1:
break
end_token_positions.append(pos)
search_start = pos + 1
think_regions = []
for match in re.finditer(
self.thinking_tag_pattern, current_text, flags=re.DOTALL
):
think_regions.append((match.start(), match.end()))
for pos in end_token_positions:
in_think = any(
pos >= t_start and pos < t_end for t_start, t_end in think_regions
)
if not in_think:
return True
return False
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
self._update_thinking_state(current_text)
if self.in_thinking_tag:
return DeltaMessage(content=delta_text)
if self._should_buffer_content(delta_text):
buffered_output = self._process_buffer(delta_text)
return DeltaMessage(content=buffered_output) if buffered_output else None
if self._is_end_tool_calls(current_text):
return DeltaMessage(content=delta_text)
safe_content, potential_tag = self._split_content_for_buffering(delta_text)
if potential_tag:
self.pending_buffer += potential_tag
return DeltaMessage(content=safe_content) if safe_content else None
processed_current_text = self.preprocess_model_output(current_text)
if self.tool_call_start_token not in processed_current_text:
if (
self.tool_call_end_token in delta_text
and self.tool_call_start_token in current_text
):
return None
if delta_text.strip() == "" and self.tool_call_start_token in current_text:
return None
if (
self._get_current_tool_index() != -1
and self.tool_call_end_token in current_text
):
self._reset_streaming_state()
return DeltaMessage(content=delta_text)
if (
self.tool_call_start_token_id is not None
and self.tool_call_start_token_id in delta_token_ids
and len(delta_token_ids) == 1
):
return None
original_tool_start = self._find_tool_start_outside_thinking(current_text)
if original_tool_start is None:
return None
content_before_tools = self._extract_content_before_tools(
current_text, delta_text, original_tool_start
)
if content_before_tools:
return DeltaMessage(content=content_before_tools)
try:
tool_content = self._extract_tool_content(current_text, original_tool_start)
current_tools_count = self._detect_tools_in_text(tool_content)
if current_tools_count == 0:
return None
if self._get_current_tool_index() == -1:
self._reset_streaming_state()
self._ensure_state_arrays(current_tools_count)
return self._handle_tool_name_streaming(
tool_content, current_tools_count
) or self._handle_tool_args_streaming(tool_content, current_tools_count)
except Exception:
logger.exception(
"An unexpected error occurred during streaming tool call handling."
)
return None
def _find_tool_start_outside_thinking(self, current_text: str) -> int | None:
"""
Find the start position of tool calls outside of thinking tags.
Args:
current_text: Current text to search
Returns:
Position of tool call start or None if not found
"""
search_start = 0
while True:
pos = current_text.find(self.tool_call_start_token, search_start)
if pos == -1:
return None
think_regions = [
(m.start(), m.end())
for m in re.finditer(
r"<think>(.*?)</think>", current_text, flags=re.DOTALL
)
]
in_think = any(
pos >= t_start and pos < t_end for t_start, t_end in think_regions
)
if not in_think:
return pos
search_start = pos + 1
def _extract_content_before_tools(
self, current_text: str, delta_text: str, tool_start: int
) -> str | None:
"""
Extract content that appears before tool calls.
Args:
current_text: Current text
delta_text: Delta text
tool_start: Start position of tools
Returns:
Content before tools or None
"""
if tool_start > 0:
delta_start_pos = len(current_text) - len(delta_text)
if delta_start_pos < tool_start:
content_part = delta_text
if delta_start_pos + len(delta_text) > tool_start:
content_part = delta_text[: tool_start - delta_start_pos]
return content_part if content_part else None
return None
def _extract_tool_content(self, current_text: str, tool_start: int) -> str:
"""
Extract tool content from current text starting at tool_start.
Args:
current_text: Current text
tool_start: Start position of tool calls
Returns:
Extracted tool content
"""
tool_content_start = tool_start + len(self.tool_call_start_token)
tool_content = current_text[tool_content_start:]
end_pos = tool_content.find(self.tool_call_end_token)
if end_pos != -1:
tool_content = tool_content[:end_pos]
return tool_content
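# Usage sketch (illustrative; tokenizer and request are assumed to be a
# vocab-bearing tokenizer and a ChatCompletionRequest):
#
#   parser = MinimaxToolParser(tokenizer)
#   out = ('<tool_calls>\n'
#          '{"name": "get_weather", "arguments": {"city": "Paris"}}\n'
#          '</tool_calls>')
#   info = parser.extract_tool_calls(out, request)
#   assert info.tools_called and info.tool_calls[0].function.name == "get_weather"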

View File

@@ -0,0 +1,585 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
from enum import Enum, auto
from random import choices
from string import ascii_letters, digits
from typing import Any
import ijson
import regex as re
from pydantic import Field
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
logger = init_logger(__name__)
ALPHANUMERIC = ascii_letters + digits
class StreamingState(Enum):
"""Enum for tracking the current streaming parsing state."""
WAITING_FOR_TOOL_START = auto()
WAITING_FOR_TOOL_KEY = (
auto()
) # waiting for the "name" or "arguments" key to be complete
PARSING_NAME = auto()
PARSING_NAME_COMPLETED = auto()
WAITING_FOR_ARGUMENTS_START = auto()
PARSING_ARGUMENTS = auto()
PARSING_ARGUMENTS_COMPLETED = auto()
TOOL_COMPLETE = auto()
ALL_TOOLS_COMPLETE = auto()
class MistralToolCall(ToolCall):
id: str = Field(default_factory=lambda: MistralToolCall.generate_random_id())
@staticmethod
def generate_random_id():
# Mistral Tool Call Ids must be alphanumeric with a length of 9.
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return "".join(choices(ALPHANUMERIC, k=9))
@staticmethod
def is_valid_id(id: str) -> bool:
return id.isalnum() and len(id) == 9
def _is_pre_v11_tokeniser(model_tokenizer: TokenizerLike) -> bool:
return not (
isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11
)
class MistralToolParser(ToolParser):
"""
Tool call parser for Mistral 7B Instruct v0.3, intended for use with
- [`mistral_common`](https://github.com/mistralai/mistral-common/)
- the examples/tool_chat_template_mistral.jinja template.
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
"""
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
if not isinstance(self.model_tokenizer, MistralTokenizer):
logger.info("Non-Mistral tokenizer detected when using a Mistral model...")
# initialize properties used for state when parsing tool calls in
# streaming mode
self.prev_tool_call_arr: list[dict[str, Any]] = []
self.current_tool_id: int = -1
self.streaming_state: StreamingState = StreamingState.WAITING_FOR_TOOL_START
# For streaming pre v11 tokenizer tool calls
self.current_tool_name: str | None = None
self.current_tool_mistral_id: str | None = None
self.starting_new_tool = False
if _is_pre_v11_tokeniser(self.model_tokenizer):
self.parse_coro = ijson.parse_coro(
self.update_stream_state_pre_v11_tokenizer()
)
self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.vocab.get(self.bot_token)
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
self._is_pre_v11 = _is_pre_v11_tokeniser(self.model_tokenizer)
if self.bot_token_id is None:
raise RuntimeError(
"Mistral Tool Parser could not locate the tool call token in "
"the tokenizer!"
)
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
request = super().adjust_request(request)
if (
not isinstance(self.model_tokenizer, MistralTokenizer)
and request.tools
and request.tool_choice != "none"
):
# Do not skip special tokens when using chat template
# with Mistral parser as TOOL_CALL token is needed
# for tool detection.
# Note: we don't want skip_special_tokens=False
# with MistralTokenizer as it is incompatible
request.skip_special_tokens = False
return request
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response. Requires
find-and-replacing single quotes with double quotes for JSON parsing;
make sure your tool call arguments never include quotes!
"""
# case -- if a tool call token is not present, return a text response
if self.bot_token not in model_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
# first remove the BOT token
tool_content = model_output.replace(self.bot_token, "").strip()
try:
try:
if not self._is_pre_v11:
function_call_arr = []
for single_tool_content in model_output.split(self.bot_token):
if "{" not in single_tool_content:
continue
end_name = single_tool_content.find("{")
fn_name, args = (
single_tool_content[:end_name],
single_tool_content[end_name:],
)
# fn_name is encoded outside serialized json dump
# only arguments are serialized
function_call_arr.append(
{"name": fn_name, "arguments": json.loads(args)}
)
else:
function_call_arr = json.loads(tool_content)
except json.JSONDecodeError:
# use a regex to find the part corresponding to the tool call.
# NOTE: This use case should not happen if the model is trained
# correctly. It's an easy possible fix so it's included, but
# can be brittle for very complex / highly nested tool calls
raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
function_call_arr = json.loads(raw_tool_call)
# Tool Call
tool_calls: list[MistralToolCall] = [
MistralToolCall(
type="function",
function=FunctionCall(
name=raw_function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(
raw_function_call["arguments"], ensure_ascii=False
),
),
)
for raw_function_call in function_call_arr
]
# get any content before the tool call
content = model_output.split(self.bot_token)[0]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if len(content) > 0 else None,
)
except Exception:
logger.exception("Error in extracting tool call from response.")
# return information to just treat the tool call as regular JSON
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=tool_content
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
if self.bot_token_id not in current_token_ids:
# if the tool call token is not in the tokens generated so far,
# append output to contents since it's not a tool
return DeltaMessage(content=delta_text)
# if the tool call token IS in the tokens generated so far, that
# means we're parsing as tool calls now
try:
if _is_pre_v11_tokeniser(self.model_tokenizer):
return self._extract_tool_calls_streaming_pre_v11_tokenizer(
delta_text=delta_text,
delta_token_ids=delta_token_ids,
)
else:
return self._extract_tool_calls_streaming(
delta_text=delta_text, delta_token_ids=delta_token_ids
)
except Exception:
logger.exception("Error trying to handle streaming tool call.")
return None
def _extract_tool_calls_streaming(
self,
delta_text: str,
delta_token_ids: Sequence[int],
) -> DeltaMessage | None:
"""
Extracts tool calls for Mistral models
doing tool calls of the following format:
`[TOOL_CALLS]add{"a": 3.5, "b": 4}`
"""
additional_content: str = ""
if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START:
# this is the first tool call
assert self.bot_token_id in delta_token_ids
if not delta_text.startswith(self.bot_token):
additional_content += delta_text.split(self.bot_token)[0]
delta_text = self.bot_token + "".join(
delta_text.split(self.bot_token)[1:]
)
delta_tool_calls = self._generate_delta_tool_call(delta_text)
if not additional_content and len(delta_tool_calls) == 0:
if self.streaming_state in [
StreamingState.PARSING_ARGUMENTS,
StreamingState.PARSING_ARGUMENTS_COMPLETED,
StreamingState.TOOL_COMPLETE,
StreamingState.ALL_TOOLS_COMPLETE,
]:
# Return an empty DeltaMessage once the tool calls are all done
# so that finish_reason gets set.
return DeltaMessage()
else:
# return None when the tool is not likely to be finished
# This can occur when the name is being parsed for example
# and we wait for the name to be complete
# before sending the function name
return None
delta = DeltaMessage()
if additional_content:
delta.content = additional_content
if len(delta_tool_calls) > 0:
delta.tool_calls = delta_tool_calls
# HACK: serving_chat.py inspects the internal state of tool parsers
# when determining its final streaming delta, automatically
# adding autocompleted JSON.
# These two lines avoid that nonsense while ensuring finish_reason
# is set to tool_calls when at least one tool is called.
if delta_tool_calls and not self.prev_tool_call_arr:
self.prev_tool_call_arr = [{"arguments": {}}]
return delta
def _generate_delta_tool_call(self, delta_text: str) -> list[DeltaToolCall]:
if delta_text == "" or delta_text is None:
return []
delta_function_name = None
tool_id = None
if self.streaming_state not in [
StreamingState.PARSING_NAME,
StreamingState.PARSING_ARGUMENTS,
] and delta_text.startswith(self.bot_token):
self.current_tool_id += 1
self.streaming_state = StreamingState.PARSING_NAME
delta_text = delta_text.replace(self.bot_token, "", 1)
if self.streaming_state == StreamingState.PARSING_NAME:
if self.current_tool_name is None:
self.current_tool_name = ""
# The name stops where the arguments start
# And the arguments start with the `{` char
if "{" in delta_text:
tool_id = MistralToolCall.generate_random_id()
delta_function_name = delta_text.split("{")[0]
self.current_tool_name += delta_function_name
delta_text = delta_text[len(delta_function_name) :]
self.streaming_state = StreamingState.PARSING_ARGUMENTS
else:
# we want to send the tool name once it's complete
self.current_tool_name += delta_text
return []
if self.streaming_state == StreamingState.PARSING_ARGUMENTS:
next_function_text = None
if self.bot_token in delta_text:
# current tool call is over
delta_arguments = ""
delta_arguments += delta_text.split(self.bot_token)[0]
next_function_text = delta_text[len(delta_arguments) :]
self.streaming_state = StreamingState.TOOL_COMPLETE
else:
delta_arguments = delta_text
ret = []
if self.current_tool_name or delta_arguments:
ret += [
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=tool_id,
function=DeltaFunctionCall(
name=self.current_tool_name, arguments=delta_arguments
).model_dump(exclude_none=True),
)
]
self.current_tool_name = None
if next_function_text:
ret += self._generate_delta_tool_call(next_function_text)
return ret
# Should not happen
return []
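# Example (illustrative): fed the single chunk
#   [TOOL_CALLS]add{"a": 3.5, "b": 4}
# _generate_delta_tool_call emits one DeltaToolCall with function name
# "add" and arguments delta '{"a": 3.5, "b": 4}'. With smaller chunks the
# name is withheld until the opening '{' arrives, because only then is the
# name known to be complete.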
@ijson.coroutine
def update_stream_state_pre_v11_tokenizer(self):
while True:
(prefix, event, value) = yield
if prefix == "item" and event == "start_map":
self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY
if prefix == "item" and event == "map_key" and value == "name":
self.streaming_state = StreamingState.PARSING_NAME
if prefix == "item.name" and event == "string":
self.current_tool_name = value
self.streaming_state = StreamingState.PARSING_NAME_COMPLETED
if prefix == "item" and event == "map_key" and value == "arguments":
self.streaming_state = StreamingState.WAITING_FOR_ARGUMENTS_START
if prefix == "item.arguments" and event == "start_map":
self.streaming_state = StreamingState.PARSING_ARGUMENTS
if prefix == "item.arguments" and event == "end_map":
self.streaming_state = StreamingState.PARSING_ARGUMENTS_COMPLETED
if prefix == "item" and event == "end_map":
self.streaming_state = StreamingState.TOOL_COMPLETE
if prefix == "" and event == "end_array":
self.streaming_state = StreamingState.ALL_TOOLS_COMPLETE
def _extract_tool_calls_streaming_pre_v11_tokenizer(
self,
delta_text: str,
delta_token_ids: Sequence[int],
) -> DeltaMessage | None:
"""
Extracts tool calls for Mistral models
doing tool calls of the following format:
`[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}`
"""
assert self.parse_coro is not None
content = None
delta_tool_calls: list[DeltaToolCall] = []
current_tool_call: DeltaToolCall = DeltaToolCall(
index=self.current_tool_id, type="function"
)
current_tool_call_modified = False
if self.bot_token_id in delta_token_ids:
# this is the first tool call
if not delta_text.startswith(self.bot_token):
content = delta_text.split(self.bot_token)[0]
delta_text = "".join(delta_text.split(self.bot_token)[1:])
# Cut smartly the delta text to catch the ijson events
# as ijson does not give us the index in the text at each event.
# We need to cut so that we know
# where in the text the events are emitted from.
while len(delta_text) > 0:
streaming_state_before_parse = self.streaming_state
if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START:
delta_to_be_parsed, delta_text = self._split_delta(
delta_text=delta_text,
stop_after_opening_curly_braces=1,
)
elif self.streaming_state == StreamingState.WAITING_FOR_TOOL_KEY:
# Wait until another key is sent
# or the current tool is completed
delta_to_be_parsed, delta_text = self._split_delta(
delta_text=delta_text,
stop_after_colon=1,
stop_after_opening_curly_braces=1,
# if the tool ends, we want to separate
# at the start of the next tool
)
elif self.streaming_state == StreamingState.PARSING_NAME:
delta_to_be_parsed, delta_text = self._split_delta(
delta_text=delta_text,
stop_after_comma=1,
stop_after_closing_brackets=1,
)
elif self.streaming_state == StreamingState.WAITING_FOR_ARGUMENTS_START:
delta_to_be_parsed, delta_text = self._split_delta(
delta_text=delta_text,
stop_after_opening_curly_braces=1,
)
elif self.streaming_state == StreamingState.PARSING_ARGUMENTS:
delta_to_be_parsed, delta_text = self._split_delta(
delta_text=delta_text,
stop_after_closing_curly_braces=1,
# we could be more clever
# by listening to item.arguments.* start_map events
# and know how many curly braces we can allow
)
elif self.streaming_state in [
StreamingState.PARSING_ARGUMENTS_COMPLETED,
StreamingState.PARSING_NAME_COMPLETED,
]:
delta_to_be_parsed, delta_text = self._split_delta(
delta_text=delta_text,
stop_after_closing_curly_braces=1,
stop_after_closing_brackets=1,
)
elif self.streaming_state == StreamingState.TOOL_COMPLETE:
delta_to_be_parsed, delta_text = self._split_delta(
delta_text=delta_text,
stop_after_opening_curly_braces=1,
stop_after_closing_brackets=1,
)
elif self.streaming_state == StreamingState.ALL_TOOLS_COMPLETE:
content = delta_text
delta_text = ""
else:
delta_to_be_parsed = delta_text
delta_text = ""
if self.streaming_state != StreamingState.ALL_TOOLS_COMPLETE:
self.parse_coro.send(delta_to_be_parsed.encode("utf-8"))
# Given the parsed text and the possible streaming state change,
# let's add to the tool delta
if (
(streaming_state_before_parse != self.streaming_state)
and streaming_state_before_parse
in [StreamingState.WAITING_FOR_TOOL_START, StreamingState.TOOL_COMPLETE]
and self.streaming_state
not in [
StreamingState.ALL_TOOLS_COMPLETE,
StreamingState.TOOL_COMPLETE,
StreamingState.WAITING_FOR_TOOL_START,
]
):
# starting a new tool call
if current_tool_call_modified:
if self.current_tool_mistral_id is not None:
current_tool_call.id = self.current_tool_mistral_id
self.current_tool_mistral_id = None
delta_tool_calls.append(current_tool_call)
current_tool_call_modified = False
self.current_tool_id += 1
self.current_tool_mistral_id = MistralToolCall.generate_random_id()
current_tool_call = DeltaToolCall(
index=self.current_tool_id,
type="function",
)
if current_tool_call.function is None:
current_tool_call.function = DeltaFunctionCall()
if self.current_tool_name is not None:
# we have the complete tool name
current_tool_call_modified = True
current_tool_call.function.name = self.current_tool_name
self.current_tool_name = None
if self.streaming_state == StreamingState.PARSING_NAME_COMPLETED:
self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY
if self.streaming_state in [
StreamingState.PARSING_ARGUMENTS,
StreamingState.PARSING_ARGUMENTS_COMPLETED,
]:
if self.streaming_state == StreamingState.PARSING_ARGUMENTS_COMPLETED:
self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY
# the delta_to_be_parsed is part of arguments.
current_tool_call_modified = True
if current_tool_call.function.arguments is None:
current_tool_call.function.arguments = delta_to_be_parsed
else:
current_tool_call.function.arguments += delta_to_be_parsed
if streaming_state_before_parse != StreamingState.PARSING_ARGUMENTS:
# It's the first chunk of arg. let's lstrip it
current_tool_call.function.arguments = (
current_tool_call.function.arguments.lstrip()
)
if current_tool_call_modified:
if self.current_tool_mistral_id is not None:
current_tool_call.id = self.current_tool_mistral_id
self.current_tool_mistral_id = None
delta_tool_calls.append(current_tool_call)
# HACK: serving_chat.py inspects the internal state of tool parsers
# when determining its final streaming delta, automatically
# adding autocompleted JSON.
# These two lines avoid that nonsense while ensuring finish_reason
# is set to tool_calls when at least one tool is called.
if delta_tool_calls and not self.prev_tool_call_arr:
self.prev_tool_call_arr = [{"arguments": {}}]
if content or len(delta_tool_calls) > 0:
delta_message = DeltaMessage()
if content:
delta_message.content = content
if len(delta_tool_calls) > 0:
delta_message.tool_calls = delta_tool_calls
return delta_message
else:
if self.streaming_state == StreamingState.ALL_TOOLS_COMPLETE:
return DeltaMessage()
else:
return None
def _split_delta(
self,
delta_text: str,
stop_after_quotes: int = -1,
stop_after_opening_curly_braces: int = -1,
stop_after_closing_curly_braces: int = -1,
stop_after_closing_brackets: int = -1,
stop_after_colon: int = -1,
stop_after_comma: int = -1,
) -> tuple[str, str]:
delta_to_be_parsed = ""
for i, c in enumerate(delta_text):
if c in ['"', "'"]:
delta_to_be_parsed += c
stop_after_quotes -= 1
if stop_after_quotes == 0:
return (delta_to_be_parsed, delta_text[i + 1 :])
elif c == "{":
delta_to_be_parsed += c
stop_after_opening_curly_braces -= 1
if stop_after_opening_curly_braces == 0:
return (delta_to_be_parsed, delta_text[i + 1 :])
elif c == "}":
delta_to_be_parsed += c
stop_after_closing_curly_braces -= 1
if stop_after_closing_curly_braces == 0:
return (delta_to_be_parsed, delta_text[i + 1 :])
elif c == "]":
delta_to_be_parsed += c
stop_after_closing_brackets -= 1
if stop_after_closing_brackets == 0:
return (delta_to_be_parsed, delta_text[i + 1 :])
elif c == ":":
delta_to_be_parsed += c
stop_after_colon -= 1
if stop_after_colon == 0:
return (delta_to_be_parsed, delta_text[i + 1 :])
elif c == ",":
delta_to_be_parsed += c
stop_after_comma -= 1
if stop_after_comma == 0:
return (delta_to_be_parsed, delta_text[i + 1 :])
else:
delta_to_be_parsed += c
return (delta_to_be_parsed, "")

View File

@@ -0,0 +1,366 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
from collections.abc import Sequence
from typing import Any
import regex as re
from transformers import PreTrainedTokenizerBase
import vllm.envs as envs
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
logger = init_logger(__name__)
class _UnexpectedAstError(Exception):
pass
class Olmo3PythonicToolParser(ToolParser):
"""
Tool call parser for Olmo 3 models that produce tool calls as
newline-separated pythonic strings.
Used when --enable-auto-tool-choice --tool-call-parser olmo3 are all set
Code copied from pythonic_tool_parser.py and updated to handle
- newline separated pythonic tool calls.
- argument values being null/true/false instead of Pythonic literals.
"""
# TODO(mdepinet): Possible future improvements:
# 1. Support text + tools separated by either <|python_tag|> or \n\n
# 2. Support tools outside of a list (or separated by a semicolon).
# This depends on item 1 for consistent streaming.
# Neither of these are necessary for e.g. ToolACE, but both would help make
# Llama3.2 models more reliable.
TOOL_CALL_REGEX = re.compile(
r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
re.DOTALL,
)
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
# Rename for readability. This is NOT a tool id.
@property
def current_tool_index(self) -> int:
return self.current_tool_id
@current_tool_index.setter
def current_tool_index(self, value: int) -> None:
self.current_tool_id = value
def extract_tool_calls(
self, model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
"""
original_model_output = model_output
# Remove xml tags.
match = re.search(
r"<function_calls>(.*?)</function_calls>", model_output, re.DOTALL
)
if match:
model_output = match.group(1).strip()
# Make the newline separated function calls into a list.
model_output = ", ".join(
[line.strip() for line in model_output.splitlines() if line.strip()]
)
model_output = f"[{model_output}]"
is_tool_call_pattern = False
try:
is_tool_call_pattern = (
self.TOOL_CALL_REGEX.match(
model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
)
is not None
)
except TimeoutError:
logger.warning("Regex timeout occurred when matching tool call pattern.")
logger.debug(
"Regex timeout occurred when matching user input: %s", model_output
)
if not is_tool_call_pattern:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=original_model_output
)
try:
module = ast.parse(model_output)
parsed = getattr(module.body[0], "value", None)
if isinstance(parsed, ast.List) and all(
isinstance(e, ast.Call) for e in parsed.elts
):
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=[
_handle_single_tool(e) # type: ignore
for e in parsed.elts
],
content=None,
)
else:
raise _UnexpectedAstError(
"Tool output must be a list of function calls"
)
except Exception:
logger.exception("Error in extracting tool call from response.")
# Treat as regular text
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=original_model_output
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
# All function calls start with the <function_calls> tag.
# But since this is streaming, we may have seen only part of the tag.
if not current_text.startswith("<"):
return DeltaMessage(content=delta_text)
try:
# Remove xml tags.
if current_text.startswith("<function_calls>"):
current_text = current_text[len("<function_calls>") :]
if current_text.endswith("</function_calls>"):
current_text = current_text[: -len("</function_calls>")]
valid_and_added_text = _make_valid_python(current_text)
if valid_and_added_text is None:
return None
valid_text, added_text = valid_and_added_text
# Make the newline separated function calls into a list.
valid_text = ", ".join(
[line.strip() for line in valid_text.splitlines() if line.strip()]
)
valid_text = f"[{valid_text}]"
module = ast.parse(valid_text)
parsed = getattr(module.body[0], "value", None)
if not isinstance(parsed, ast.List) or not all(
isinstance(e, ast.Call) for e in parsed.elts
):
raise _UnexpectedAstError(
"Tool output must be a sequence of newline-separated calls"
)
tool_calls = [
_handle_single_tool(e) # type: ignore
for e in parsed.elts
]
tool_deltas = []
for index, new_call in enumerate(tool_calls):
if index < self.current_tool_index:
continue
self.current_tool_index = index
if len(self.streamed_args_for_tool) == index:
self.streamed_args_for_tool.append("")
new_call_complete = index < len(tool_calls) - 1 or ")" not in added_text
if new_call_complete:
self.current_tool_index += 1
withheld_suffix = added_text[:-1] if not new_call_complete else ""
if not new_call_complete and added_text[-1] == ")":
# Function call is incomplete. Withhold the closing bracket.
withheld_suffix = withheld_suffix + "}"
# Strings get single quotes in the model-produced string.
# JSON requires double quotes.
withheld_suffix = withheld_suffix.replace("'", '"')
delta = _compute_tool_delta(
self.streamed_args_for_tool[index], new_call, index, withheld_suffix
)
if delta is not None:
tool_deltas.append(delta)
if (
delta.function is not None
and delta.function.arguments is not None
):
self.streamed_args_for_tool[index] += delta.function.arguments
# HACK: serving_chat.py inspects the internal state of tool parsers
# when determining its final streaming delta, automatically
# adding autocompleted JSON.
# These two lines avoid that nonsense while ensuring finish_reason
# is set to tool_calls when at least one tool is called.
if tool_deltas and not self.prev_tool_call_arr:
self.prev_tool_call_arr = [{"arguments": {}}]
if tool_deltas:
return DeltaMessage(tool_calls=tool_deltas)
elif not added_text and self.current_tool_id > 0:
# Return an empty DeltaMessage once the tool calls are all done
# so that finish_reason gets set.
return DeltaMessage(content="")
else:
return None
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction error"
)
return None
def _get_parameter_value(val: ast.expr) -> Any:
if isinstance(val, ast.Constant):
return val.value
elif isinstance(val, ast.Dict):
if not all(isinstance(k, ast.Constant) for k in val.keys):
raise _UnexpectedAstError("Dict tool call arguments must have literal keys")
return {
k.value: _get_parameter_value(v) # type: ignore
for k, v in zip(val.keys, val.values)
}
elif isinstance(val, ast.List):
return [_get_parameter_value(v) for v in val.elts]
# The model may return function calls where the values are null/true/false
# because the system prompt has API description in json.
elif isinstance(val, ast.Name) and val.id in ["null", "true", "false"]:
if val.id == "null":
return None
elif val.id == "true":
return True
elif val.id == "false":
return False
else:
raise _UnexpectedAstError("Tool call arguments must be literals")
def _handle_single_tool(call: ast.Call) -> ToolCall:
if not isinstance(call.func, ast.Name):
raise _UnexpectedAstError("Invalid tool call name")
function_name = call.func.id
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = _get_parameter_value(keyword.value)
return ToolCall(
type="function",
function=FunctionCall(
name=function_name, arguments=json.dumps(arguments, ensure_ascii=False)
),
)
def _make_valid_python(text: str) -> tuple[str, str] | None:
bracket_stack = []
for index, char in enumerate(text):
if char in {"[", "(", "{"}:
bracket_stack.append(char)
elif char == "]":
if not bracket_stack or bracket_stack.pop() != "[":
raise _UnexpectedAstError("Mismatched square brackets")
elif char == ")":
if not bracket_stack or bracket_stack.pop() != "(":
raise _UnexpectedAstError("Mismatched parentheses")
elif char == "}":
if not bracket_stack or bracket_stack.pop() != "{":
raise _UnexpectedAstError("Mismatched curly braces")
elif char in {"'", '"'}:
if bracket_stack and bracket_stack[-1] == char:
if index > 0 and text[index - 1] == "\\":
# Treat an escaped quote as a regular character
pass
else:
bracket_stack.pop()
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
# Double quote within a single quote string or vice versa.
pass
else:
bracket_stack.append(char)
text = text.rstrip()
if text.endswith("=") or text.endswith(":"):
# Since we have no type information for this property/parameter value,
# we can't fill in a valid value.
return None
if bracket_stack and bracket_stack[-1] == "{":
trailing_dict_text = text[: text.rfind("{")]
num_keys = trailing_dict_text.count(":")
num_values = trailing_dict_text.count(",")
if num_keys <= num_values:
return None # Incomplete property name within parameter value
if bracket_stack and bracket_stack[-1] == "(":
trailing_params_text = text[: text.rfind("(")]
num_full_param_names = trailing_params_text.count("=")
num_full_param_values = trailing_params_text.count(",")
if num_full_param_names <= num_full_param_values:
return None # Incomplete parameter name
if text.endswith(","):
text = text[:-1]
if (
bracket_stack
and bracket_stack[-1] == "["
and not text.endswith("[")
and not text.endswith(")")
):
return None # Incomplete function name
added_text = ""
for char in reversed(bracket_stack):
if char == "[":
added_text += "]"
elif char == "(":
added_text += ")"
elif char == "{":
added_text += "}"
elif char == "'":
added_text += "'"
elif char == '"':
added_text += '"'
return text + added_text, added_text
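# Example (illustrative): _make_valid_python('get_weather(city="Par')
# returns ('get_weather(city="Par")', '")'): the unclosed string and
# parenthesis are completed so the partial call can be parsed with ast,
# while added_text records what was appended and must therefore be
# withheld from the streamed arguments.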
def _compute_tool_delta(
previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str
) -> DeltaToolCall | None:
new_call_args = new_call.function.arguments
if withheld_suffix:
assert new_call_args.endswith(withheld_suffix)
new_call_args = new_call_args[: -len(withheld_suffix)]
if not previously_sent_args:
return DeltaToolCall(
id=new_call.id,
type="function",
index=index,
function=DeltaFunctionCall(
name=new_call.function.name,
arguments=new_call_args,
),
)
arg_diff = new_call_args[len(previously_sent_args) :]
return (
DeltaToolCall(
id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff)
)
if arg_diff
else None
)
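# Usage sketch (illustrative): for a complete response such as
#   <function_calls>
#   get_weather(city="Paris")
#   </function_calls>
# extract_tool_calls strips the XML tags, joins the newline-separated calls
# into a Python list literal, parses it with ast, and yields one ToolCall
# named "get_weather" with arguments '{"city": "Paris"}'.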

View File

@@ -0,0 +1,102 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
from typing import TYPE_CHECKING
from vllm.entrypoints.openai.parser.harmony_utils import parse_output_into_messages
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
if TYPE_CHECKING:
from vllm.tokenizers import TokenizerLike
else:
TokenizerLike = object
logger = init_logger(__name__)
class OpenAIToolParser(ToolParser):
def __init__(self, tokenizer: "TokenizerLike"):
super().__init__(tokenizer)
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
token_ids: Sequence[int] | None = None,
) -> ExtractedToolCallInformation:
if token_ids is None:
raise NotImplementedError(
"OpenAIToolParser requires token IDs and does not support text-based extraction." # noqa: E501
)
parser = parse_output_into_messages(token_ids)
tool_calls = []
final_content = None
commentary_content = None
if len(parser.messages) > 0:
for msg in parser.messages:
if len(msg.content) < 1:
continue
msg_text = msg.content[0].text
if msg.recipient and msg.recipient.startswith("functions."):
# If no content-type is given assume JSON, as that's the
# most common case with gpt-oss models.
if not msg.content_type or "json" in msg.content_type:
# load and dump the JSON text to check validity and
# remove any extra newlines or other odd formatting
try:
tool_args = json.dumps(json.loads(msg_text))
except json.JSONDecodeError:
logger.exception(
"Error decoding JSON tool call from response."
)
tool_args = msg_text
else:
tool_args = msg_text
tool_calls.append(
ToolCall(
type="function",
function=FunctionCall(
name=msg.recipient.split("functions.")[1],
arguments=tool_args,
),
)
)
elif msg.channel == "final":
final_content = msg_text
elif msg.channel == "commentary" and not msg.recipient:
commentary_content = msg_text
return ExtractedToolCallInformation(
tools_called=len(tool_calls) > 0,
tool_calls=tool_calls,
# Prefer final content over commentary content when both are present;
# commentary content consists of tool-call preambles meant to be shown
# to the user.
content=final_content or commentary_content,
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
raise NotImplementedError(
"Not being used, manual parsing in serving_chat.py" # noqa: E501
)

View File

@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
from typing import Any
import regex as re
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
logger = init_logger(__name__)
class Phi4MiniJsonToolParser(ToolParser):
"""
Tool call parser for phi-4-mini models intended for use with the
examples/tool_chat_template_llama.jinja template.
Used when --enable-auto-tool-choice --tool-call-parser phi4_mini_json
are both set
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None:
super().__init__(tokenizer)
# initialize properties used for state when parsing tool calls in
# streaming mode
self.prev_tool_call_arr: list[dict[str, Any]] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: list[
str
] = [] # map what has been streamed for each tool so far to a list
self.bot_token: str = "functools"
def extract_tool_calls(
self, model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
"""
logger.debug("Model output: %s", model_output)
pattern = r"functools\[(.*?)\]"
matches = re.search(pattern, model_output, re.DOTALL)
if not matches:
logger.debug("No function calls found")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
try:
function_call_arr: list[dict[str, Any]] = []
try:
json_content = "[" + matches.group(1) + "]"
function_call_arr = json.loads(json_content)
logger.debug(
"Successfully extracted %d function calls", len(function_call_arr)
)
except json.JSONDecodeError as e:
logger.error(
"Failed to parse function calls from model output. Error: %s",
str(e),
)
# Without valid JSON there are no tool calls to return; treat the
# output as plain content instead of reporting tools_called=True.
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
tool_calls: list[ToolCall] = [
ToolCall(
id=make_tool_call_id(),
type="function",
function=FunctionCall(
name=raw_function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(
raw_function_call["arguments"]
if "arguments" in raw_function_call
else raw_function_call["parameters"],
ensure_ascii=False,
),
),
)
for raw_function_call in function_call_arr
]
# NOTE: content before the tool call is not extracted; it is dropped.
ret = ExtractedToolCallInformation(
tools_called=True, tool_calls=tool_calls, content=None
)
return ret
except Exception:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
return None

View File

@@ -0,0 +1,332 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
from collections.abc import Sequence
from typing import Any
import regex as re
from transformers import PreTrainedTokenizerBase
import vllm.envs as envs
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
logger = init_logger(__name__)
class _UnexpectedAstError(Exception):
pass
class PythonicToolParser(ToolParser):
"""
Tool call parser for models that produce tool calls in a pythonic style,
such as Llama 3.2 and Llama 4 models.
Used when --enable-auto-tool-choice --tool-call-parser pythonic are both set
"""
# TODO(mdepinet): Possible future improvements:
# 1. Support text + tools separated by either <|python_tag|> or \n\n
# 2. Support tools outside of a list (or separated by a semicolon).
# This depends on item 1 for consistent streaming.
# Neither of these are necessary for e.g. ToolACE, but both would help make
# Llama3.2 models more reliable.
TOOL_CALL_REGEX = re.compile(
r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
re.DOTALL,
)
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
# Rename for readability. This is NOT a tool id.
@property
def current_tool_index(self) -> int:
return self.current_tool_id
@current_tool_index.setter
def current_tool_index(self, value: int) -> None:
self.current_tool_id = value
def extract_tool_calls(
self, model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
"""
is_tool_call_pattern = False
try:
is_tool_call_pattern = (
self.TOOL_CALL_REGEX.match(
model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
)
is not None
)
except TimeoutError:
logger.warning("Regex timeout occurred when matching tool call pattern.")
logger.debug(
"Regex timeout occurred when matching user input: %s", model_output
)
if not is_tool_call_pattern:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
try:
module = ast.parse(model_output)
parsed = getattr(module.body[0], "value", None)
if isinstance(parsed, ast.List) and all(
isinstance(e, ast.Call) for e in parsed.elts
):
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=[
_handle_single_tool(e) # type: ignore
for e in parsed.elts
],
content=None,
)
else:
raise _UnexpectedAstError(
"Tool output must be a list of function calls"
)
except Exception:
logger.exception("Error in extracting tool call from response.")
# Treat as regular text
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
if not current_text.startswith("["):
return DeltaMessage(content=delta_text)
try:
valid_and_added_text = _make_valid_python(current_text)
if valid_and_added_text is None:
return None
valid_text, added_text = valid_and_added_text
module = ast.parse(valid_text)
parsed = getattr(module.body[0], "value", None)
if not isinstance(parsed, ast.List) or not all(
isinstance(e, ast.Call) for e in parsed.elts
):
raise _UnexpectedAstError(
"Tool output must be a list of function calls"
)
tool_calls = [
_handle_single_tool(e) # type: ignore
for e in parsed.elts
]
tool_deltas = []
for index, new_call in enumerate(tool_calls):
if index < self.current_tool_index:
continue
self.current_tool_index = index
if len(self.streamed_args_for_tool) == index:
self.streamed_args_for_tool.append("")
new_call_complete = (
index < len(tool_calls) - 1 or ")]" not in added_text
)
if new_call_complete:
self.current_tool_index += 1
withheld_suffix = added_text[:-2] if not new_call_complete else ""
if not new_call_complete and added_text[-2] == ")":
# Function call is incomplete. Withhold the closing bracket.
withheld_suffix = withheld_suffix + "}"
# Strings get single quotes in the model-produced string.
# JSON requires double quotes.
withheld_suffix = withheld_suffix.replace("'", '"')
delta = _compute_tool_delta(
self.streamed_args_for_tool[index], new_call, index, withheld_suffix
)
if delta is not None:
tool_deltas.append(delta)
if (
delta.function is not None
and delta.function.arguments is not None
):
self.streamed_args_for_tool[index] += delta.function.arguments
# HACK: serving_chat.py inspects the internal state of tool parsers
# when determining its final streaming delta, automatically
# adding autocompleted JSON.
# These two lines avoid that nonsense while ensuring finish_reason
# is set to tool_calls when at least one tool is called.
if tool_deltas and not self.prev_tool_call_arr:
self.prev_tool_call_arr = [{"arguments": {}}]
if tool_deltas:
return DeltaMessage(tool_calls=tool_deltas)
elif not added_text and self.current_tool_id > 0:
# Return an empty DeltaMessage once the tool calls are all done
# so that finish_reason gets set.
return DeltaMessage(content="")
else:
return None
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction error"
)
return None
def _get_parameter_value(val: ast.expr) -> Any:
if isinstance(val, ast.Constant):
return val.value
elif isinstance(val, ast.Dict):
if not all(isinstance(k, ast.Constant) for k in val.keys):
raise _UnexpectedAstError("Dict tool call arguments must have literal keys")
return {
k.value: _get_parameter_value(v) # type: ignore
for k, v in zip(val.keys, val.values)
}
elif isinstance(val, ast.List):
return [_get_parameter_value(v) for v in val.elts]
else:
raise _UnexpectedAstError("Tool call arguments must be literals")
def _handle_single_tool(call: ast.Call) -> ToolCall:
if not isinstance(call.func, ast.Name):
raise _UnexpectedAstError("Invalid tool call name")
function_name = call.func.id
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = _get_parameter_value(keyword.value)
return ToolCall(
type="function",
function=FunctionCall(
name=function_name, arguments=json.dumps(arguments, ensure_ascii=False)
),
)
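# e.g. the call `get_weather(city="SF", units={"temp": "C"})` (hypothetical
# tool) would yield FunctionCall(name="get_weather",
# arguments='{"city": "SF", "units": {"temp": "C"}}').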
def _make_valid_python(text: str) -> tuple[str, str] | None:
bracket_stack = []
for index, char in enumerate(text):
if char in {"[", "(", "{"}:
bracket_stack.append(char)
elif char == "]":
if not bracket_stack or bracket_stack.pop() != "[":
raise _UnexpectedAstError("Mismatched square brackets")
elif char == ")":
if not bracket_stack or bracket_stack.pop() != "(":
raise _UnexpectedAstError("Mismatched parentheses")
elif char == "}":
if not bracket_stack or bracket_stack.pop() != "{":
raise _UnexpectedAstError("Mismatched curly braces")
elif char in {"'", '"'}:
if bracket_stack and bracket_stack[-1] == char:
if index > 0 and text[index - 1] == "\\":
# Treat an escaped quote as a regular character
pass
else:
bracket_stack.pop()
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
# Double quote within a single quote string or vice versa.
pass
else:
bracket_stack.append(char)
text = text.rstrip()
if text.endswith("=") or text.endswith(":"):
# Since we have no type information for this property/parameter value,
# we can't fill in a valid value.
return None
if bracket_stack and bracket_stack[-1] == "{":
trailing_dict_text = text[: text.rfind("{")]
num_keys = trailing_dict_text.count(":")
num_values = trailing_dict_text.count(",")
if num_keys <= num_values:
return None # Incomplete property name within parameter value
if bracket_stack and bracket_stack[-1] == "(":
trailing_params_text = text[: text.rfind("(")]
num_full_param_names = trailing_params_text.count("=")
num_full_param_values = trailing_params_text.count(",")
if num_full_param_names <= num_full_param_values:
return None # Incomplete parameter name
if text.endswith(","):
text = text[:-1]
if (
bracket_stack
and bracket_stack[-1] == "["
and not text.endswith("[")
and not text.endswith(")")
):
return None # Incomplete function name
added_text = ""
for char in reversed(bracket_stack):
if char == "[":
added_text += "]"
elif char == "(":
added_text += ")"
elif char == "{":
added_text += "}"
elif char == "'":
added_text += "'"
elif char == '"':
added_text += '"'
return text + added_text, added_text
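# e.g. _make_valid_python('[get_time(tz="UTC') would return
# ('[get_time(tz="UTC")]', '")]'), completing the open quote and brackets.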
def _compute_tool_delta(
previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str
) -> DeltaToolCall | None:
new_call_args = new_call.function.arguments
if withheld_suffix:
assert new_call_args.endswith(withheld_suffix)
new_call_args = new_call_args[: -len(withheld_suffix)]
if not previously_sent_args:
return DeltaToolCall(
id=new_call.id,
type="function",
index=index,
function=DeltaFunctionCall(
name=new_call.function.name,
arguments=new_call_args,
),
)
arg_diff = new_call_args[len(previously_sent_args) :]
return (
DeltaToolCall(
id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff)
)
if arg_diff
else None
)

View File

@@ -0,0 +1,781 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
import uuid
from collections.abc import Sequence
from typing import Any
import regex as re
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
logger = init_logger(__name__)
class Qwen3CoderToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
# Override base class type - we use string IDs for tool calls
self.current_tool_id: str | None = None # type: ignore
self.streamed_args_for_tool: list[str] = []
# Sentinel tokens for streaming mode
self.tool_call_start_token: str = "<tool_call>"
self.tool_call_end_token: str = "</tool_call>"
self.tool_call_prefix: str = "<function="
self.function_end_token: str = "</function>"
self.parameter_prefix: str = "<parameter="
self.parameter_end_token: str = "</parameter>"
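# Taken together, these sentinels imply model output shaped roughly like
# (illustrative, hypothetical tool name):
#   <tool_call><function=get_weather>
#   <parameter=city>
#   Paris
#   </parameter>
#   </function></tool_call>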
self.is_tool_call_started: bool = False
self.failed_count: int = 0
# Enhanced streaming state - reset for each new message
self._reset_streaming_state()
# Regex patterns
self.tool_call_complete_regex = re.compile(
r"<tool_call>(.*?)</tool_call>", re.DOTALL
)
self.tool_call_regex = re.compile(
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL
)
self.tool_call_function_regex = re.compile(
r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
)
self.tool_call_parameter_regex = re.compile(
r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)",
re.DOTALL,
)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
raise RuntimeError(
"Qwen3 XML Tool parser could not locate tool call start/end "
"tokens in the tokenizer!"
)
logger.info(
"vLLM Successfully import tool parser %s !", self.__class__.__name__
)
def _generate_tool_call_id(self) -> str:
"""Generate a unique tool call ID."""
return f"call_{uuid.uuid4().hex[:24]}"
def _reset_streaming_state(self):
"""Reset all streaming state."""
self.current_tool_index = 0
self.is_tool_call_started = False
self.header_sent = False
self.current_tool_id = None
self.current_function_name = None
self.current_param_name = None
self.current_param_value = ""
self.param_count = 0
self.in_param = False
self.in_function = False
self.accumulated_text = ""
self.json_started = False
self.json_closed = False
# Store accumulated parameters for type conversion
self.accumulated_params = {}
self.streaming_request = None
def _get_arguments_config(
self, func_name: str, tools: list[ChatCompletionToolsParam] | None
) -> dict:
"""Extract argument configuration for a function."""
if tools is None:
return {}
for config in tools:
if not hasattr(config, "type") or not (
hasattr(config, "function") and hasattr(config.function, "name")
):
continue
if config.type == "function" and config.function.name == func_name:
if not hasattr(config.function, "parameters"):
return {}
params = config.function.parameters
if isinstance(params, dict) and "properties" in params:
return params["properties"]
elif isinstance(params, dict):
return params
else:
return {}
logger.debug("Tool '%s' is not defined in the tools list.", func_name)
return {}
def _convert_param_value(
self, param_value: str, param_name: str, param_config: dict, func_name: str
) -> Any:
"""Convert parameter value based on its type in the schema."""
# Handle null value for any type
if param_value.lower() == "null":
return None
if param_name not in param_config:
if param_config != {}:
logger.debug(
"Parsed parameter '%s' is not defined in the tool "
"parameters for tool '%s', directly returning the "
"string value.",
param_name,
func_name,
)
return param_value
if (
isinstance(param_config[param_name], dict)
and "type" in param_config[param_name]
):
param_type = str(param_config[param_name]["type"]).strip().lower()
else:
param_type = "string"
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
return param_value
elif (
param_type.startswith("int")
or param_type.startswith("uint")
or param_type.startswith("long")
or param_type.startswith("short")
or param_type.startswith("unsigned")
):
try:
return int(param_value)
except (ValueError, TypeError):
logger.debug(
"Parsed value '%s' of parameter '%s' is not an "
"integer in tool '%s', degenerating to string.",
param_value,
param_name,
func_name,
)
return param_value
elif param_type.startswith("num") or param_type.startswith("float"):
try:
float_param_value = float(param_value)
return (
float_param_value
if float_param_value - int(float_param_value) != 0
else int(float_param_value)
)
except (ValueError, TypeError):
logger.debug(
"Parsed value '%s' of parameter '%s' is not a float "
"in tool '%s', degenerating to string.",
param_value,
param_name,
func_name,
)
return param_value
elif param_type in ["boolean", "bool", "binary"]:
param_value = param_value.lower()
if param_value not in ["true", "false"]:
logger.debug(
"Parsed value '%s' of parameter '%s' is not a boolean "
"(`true` or `false`) in tool '%s', degenerating to "
"false.",
param_value,
param_name,
func_name,
)
return param_value == "true"
else:
if (
param_type in ["object", "array", "arr"]
or param_type.startswith("dict")
or param_type.startswith("list")
):
try:
param_value = json.loads(param_value)
return param_value
except (json.JSONDecodeError, TypeError, ValueError):
logger.debug(
"Parsed value '%s' of parameter '%s' cannot be "
"parsed with json.loads in tool '%s', will try "
"other methods to parse it.",
param_value,
param_name,
func_name,
)
try:
param_value = ast.literal_eval(param_value) # safer
except (ValueError, SyntaxError, TypeError):
logger.debug(
"Parsed value '%s' of parameter '%s' cannot be "
"converted via Python `ast.literal_eval()` in tool "
"'%s', degenerating to string.",
param_value,
param_name,
func_name,
)
return param_value
def _parse_xml_function_call(
self, function_call_str: str, tools: list[ChatCompletionToolsParam] | None
) -> ToolCall | None:
# Extract function name
end_index = function_call_str.index(">")
function_name = function_call_str[:end_index]
param_config = self._get_arguments_config(function_name, tools)
parameters = function_call_str[end_index + 1 :]
param_dict = {}
for match_text in self.tool_call_parameter_regex.findall(parameters):
idx = match_text.index(">")
param_name = match_text[:idx]
param_value = str(match_text[idx + 1 :])
# Remove prefix and trailing \n
if param_value.startswith("\n"):
param_value = param_value[1:]
if param_value.endswith("\n"):
param_value = param_value[:-1]
param_dict[param_name] = self._convert_param_value(
param_value, param_name, param_config, function_name
)
return ToolCall(
type="function",
function=FunctionCall(
name=function_name, arguments=json.dumps(param_dict, ensure_ascii=False)
),
)
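# e.g. the body 'get_weather>\n<parameter=city>\nParis\n</parameter>\n'
# (hypothetical tool) yields FunctionCall(name="get_weather",
# arguments='{"city": "Paris"}'), assuming the schema types "city" as a
# string.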
def _get_function_calls(self, model_output: str) -> list[str]:
# Find all tool calls
matched_ranges = self.tool_call_regex.findall(model_output)
raw_tool_calls = [
match[0] if match[0] else match[1] for match in matched_ranges
]
# Back-off strategy if no tool_call tags found
if len(raw_tool_calls) == 0:
raw_tool_calls = [model_output]
raw_function_calls = []
for tool_call in raw_tool_calls:
raw_function_calls.extend(self.tool_call_function_regex.findall(tool_call))
function_calls = [
match[0] if match[0] else match[1] for match in raw_function_calls
]
return function_calls
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
# Quick check to avoid unnecessary processing
if self.tool_call_prefix not in model_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
try:
function_calls = self._get_function_calls(model_output)
if len(function_calls) == 0:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
tool_calls = [
self._parse_xml_function_call(function_call_str, request.tools)
for function_call_str in function_calls
]
# Populate prev_tool_call_arr for serving layer to set finish_reason
self.prev_tool_call_arr.clear() # Clear previous calls
for tool_call in tool_calls:
if tool_call:
self.prev_tool_call_arr.append(
{
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
}
)
# Extract content before tool calls
content_index = model_output.find(self.tool_call_start_token)
idx = model_output.find(self.tool_call_prefix)
content_index = content_index if content_index >= 0 else idx
content = model_output[:content_index] # .rstrip()
return ExtractedToolCallInformation(
tools_called=(len(tool_calls) > 0),
tool_calls=tool_calls,
content=content if content else None,
)
except Exception:
logger.exception("Error in extracting tool call from response.")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
# Store request for type conversion
if not previous_text:
self._reset_streaming_state()
self.streaming_request = request
# If no delta text, return None unless it's an EOS token after tools
if not delta_text:
# Check if this is an EOS token after all tool calls are complete
# Check for tool calls in text even if is_tool_call_started
# is False (might have been reset after processing all tools)
if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids:
# Count complete tool calls
complete_calls = len(
self.tool_call_complete_regex.findall(current_text)
)
# If we have completed tool calls and populated
# prev_tool_call_arr
if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
# Check if all tool calls are closed
open_calls = current_text.count(
self.tool_call_start_token
) - current_text.count(self.tool_call_end_token)
if open_calls == 0:
# Return empty delta for finish_reason processing
return DeltaMessage(content="")
elif not self.is_tool_call_started and current_text:
# This is a regular content response that's now complete
return DeltaMessage(content="")
return None
# Update accumulated text
self.accumulated_text = current_text
# Check if we need to advance to next tool
if self.json_closed and not self.in_function:
# Check if this tool call has ended
tool_ends = current_text.count(self.tool_call_end_token)
if tool_ends > self.current_tool_index:
# This tool has ended, advance to next
self.current_tool_index += 1
self.header_sent = False
self.param_count = 0
self.json_started = False
self.json_closed = False
self.accumulated_params = {}
# Check if there are more tool calls
tool_starts = current_text.count(self.tool_call_start_token)
if self.current_tool_index >= tool_starts:
# No more tool calls
self.is_tool_call_started = False
# Continue processing next tool
return None
# Handle normal content before tool calls
if not self.is_tool_call_started:
# Check if tool call is starting
if (
self.tool_call_start_token_id in delta_token_ids
or self.tool_call_start_token in delta_text
):
self.is_tool_call_started = True
# Return any content before the tool call
if self.tool_call_start_token in delta_text:
content_before = delta_text[
: delta_text.index(self.tool_call_start_token)
]
if content_before:
return DeltaMessage(content=content_before)
return None
else:
# Check if we're between tool calls - skip whitespace
if (
current_text.rstrip().endswith(self.tool_call_end_token)
and delta_text.strip() == ""
):
# We just ended a tool call, skip whitespace
return None
# Normal content, no tool call
return DeltaMessage(content=delta_text)
# Check if we're between tool calls (waiting for next one)
# Count tool calls we've seen vs processed
tool_starts_count = current_text.count(self.tool_call_start_token)
if self.current_tool_index >= tool_starts_count:
# We're past all tool calls, shouldn't be here
return None
# We're in a tool call, find the current tool call portion
# Need to find the correct tool call based on current_tool_index
tool_start_positions: list[int] = []
idx = 0
while True:
idx = current_text.find(self.tool_call_start_token, idx)
if idx == -1:
break
tool_start_positions.append(idx)
idx += len(self.tool_call_start_token)
if self.current_tool_index >= len(tool_start_positions):
# No more tool calls to process yet
return None
tool_start_idx = tool_start_positions[self.current_tool_index]
# Find where this tool call ends (or current position if not ended yet)
tool_end_idx = current_text.find(self.tool_call_end_token, tool_start_idx)
if tool_end_idx == -1:
tool_text = current_text[tool_start_idx:]
else:
tool_text = current_text[
tool_start_idx : tool_end_idx + len(self.tool_call_end_token)
]
# Looking for function header
if not self.header_sent:
if self.tool_call_prefix in tool_text:
func_start = tool_text.find(self.tool_call_prefix) + len(
self.tool_call_prefix
)
func_end = tool_text.find(">", func_start)
if func_end != -1:
# Found complete function name
self.current_function_name = tool_text[func_start:func_end]
self.current_tool_id = self._generate_tool_call_id()
self.header_sent = True
self.in_function = True
# IMPORTANT: Add to prev_tool_call_arr immediately when
# we detect a tool call. This ensures
# finish_reason="tool_calls" even if parsing isn't complete
already_added = any(
tool.get("name") == self.current_function_name
for tool in self.prev_tool_call_arr
)
if not already_added:
self.prev_tool_call_arr.append(
{
"name": self.current_function_name,
"arguments": "{}", # Placeholder, will be updated later
}
)
# Send header with function info
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
id=self.current_tool_id,
function=DeltaFunctionCall(
name=self.current_function_name, arguments=""
),
type="function",
)
]
)
return None
# We've sent header, now handle function body
if self.in_function:
# Send opening brace if not sent yet
if not self.json_started and self.parameter_prefix not in delta_text:
self.json_started = True
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="{"),
)
]
)
# Make sure json_started is set if we're processing parameters
if not self.json_started:
self.json_started = True
# Check for function end in accumulated text
if not self.json_closed and self.function_end_token in tool_text:
# Close JSON
self.json_closed = True
# Extract complete tool call to update
# prev_tool_call_arr with final arguments
# Find the function content
func_start = tool_text.find(self.tool_call_prefix) + len(
self.tool_call_prefix
)
func_content_end = tool_text.find(self.function_end_token, func_start)
if func_content_end != -1:
func_content = tool_text[func_start:func_content_end]
# Parse to get the complete arguments
try:
parsed_tool = self._parse_xml_function_call(
func_content,
self.streaming_request.tools
if self.streaming_request
else None,
)
if parsed_tool:
# Update existing entry in
# prev_tool_call_arr with complete args
for i, tool in enumerate(self.prev_tool_call_arr):
if tool.get("name") == parsed_tool.function.name:
args = parsed_tool.function.arguments
self.prev_tool_call_arr[i]["arguments"] = args
break
except Exception:
pass # Ignore parsing errors during streaming
result = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="}"),
)
]
)
# Reset state for next tool
self.in_function = False
self.json_closed = True
self.accumulated_params = {}
return result
# Look for parameters
# Find all parameter starts
param_starts = []
idx = 0
while True:
idx = tool_text.find(self.parameter_prefix, idx)
if idx == -1:
break
param_starts.append(idx)
idx += len(self.parameter_prefix)
# Check if we should start a new parameter
if (
not self.in_param
and self.param_count < len(param_starts)
and len(param_starts) > self.param_count
):
# Process the next parameter
param_idx = param_starts[self.param_count]
param_start = param_idx + len(self.parameter_prefix)
remaining = tool_text[param_start:]
if ">" in remaining:
# We have the complete parameter name
name_end = remaining.find(">")
self.current_param_name = remaining[:name_end]
# Find the parameter value
value_start = param_start + name_end + 1
value_text = tool_text[value_start:]
if value_text.startswith("\n"):
value_text = value_text[1:]
# Find where this parameter ends
param_end_idx = value_text.find(self.parameter_end_token)
if param_end_idx == -1:
# No closing tag, look for next parameter or
# function end
next_param_idx = value_text.find(self.parameter_prefix)
func_end_idx = value_text.find(self.function_end_token)
if next_param_idx != -1 and (
func_end_idx == -1 or next_param_idx < func_end_idx
):
param_end_idx = next_param_idx
elif func_end_idx != -1:
param_end_idx = func_end_idx
else:
# Neither found, check if tool call is complete
if self.tool_call_end_token in tool_text:
# Tool call is complete, so parameter
# must be complete too. Use all
# remaining text before function end
param_end_idx = len(value_text)
else:
# Still streaming, wait for more content
return None
if param_end_idx != -1:
# Complete parameter found
param_value = value_text[:param_end_idx]
if param_value.endswith("\n"):
param_value = param_value[:-1]
# Store raw value for later processing
self.accumulated_params[self.current_param_name] = param_value
# Get parameter configuration for type conversion
param_config = self._get_arguments_config(
self.current_function_name or "",
self.streaming_request.tools
if self.streaming_request
else None,
)
# Convert param value to appropriate type
converted_value = self._convert_param_value(
param_value,
self.current_param_name,
param_config,
self.current_function_name or "",
)
# Build JSON fragment based on the converted type
# Use json.dumps to properly serialize the value
serialized_value = json.dumps(
converted_value, ensure_ascii=False
)
if self.param_count == 0:
json_fragment = (
f'"{self.current_param_name}": {serialized_value}'
)
else:
json_fragment = (
f', "{self.current_param_name}": {serialized_value}'
)
self.param_count += 1
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments=json_fragment),
)
]
)
# Continue accumulating a parameter value. This branch is effectively
# unused in the current implementation, since complete parameters are
# handled above.
if self.in_param:
if self.parameter_end_token in delta_text:
# End of parameter
end_idx = delta_text.find(self.parameter_end_token)
value_chunk = delta_text[:end_idx]
# Skip past > if at start
if not self.current_param_value and ">" in value_chunk:
gt_idx = value_chunk.find(">")
value_chunk = value_chunk[gt_idx + 1 :]
if not self.current_param_value and value_chunk.startswith("\n"):
value_chunk = value_chunk[1:]
# Store complete value
full_value = self.current_param_value + value_chunk
self.accumulated_params[self.current_param_name] = full_value
# Get parameter configuration for type conversion
param_config = self._get_arguments_config(
self.current_function_name or "",
self.streaming_request.tools
if self.streaming_request
else None,
)
# Convert the parameter value to the appropriate type
converted_value = self._convert_param_value(
full_value,
self.current_param_name or "",
param_config,
self.current_function_name or "",
)
# Serialize the converted value
serialized_value = json.dumps(converted_value, ensure_ascii=False)
# We have been streaming the quoted form of the value, so we only need
# to close the open JSON string; fully re-serializing here would be
# complex, so for now we simply complete the value.
self.in_param = False
self.current_param_value = ""
# Just close the current parameter string
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments='"'
), # Close the string quote
)
]
)
else:
# Continue accumulating value
value_chunk = delta_text
# Handle first chunk after param name
if not self.current_param_value and ">" in value_chunk:
gt_idx = value_chunk.find(">")
value_chunk = value_chunk[gt_idx + 1 :]
if not self.current_param_value and value_chunk.startswith("\n"):
value_chunk = value_chunk[1:]
if value_chunk:
# Stream the escaped delta
prev_escaped = (
json.dumps(self.current_param_value, ensure_ascii=False)[
1:-1
]
if self.current_param_value
else ""
)
self.current_param_value += value_chunk
full_escaped = json.dumps(
self.current_param_value, ensure_ascii=False
)[1:-1]
delta_escaped = full_escaped[len(prev_escaped) :]
if delta_escaped:
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments=delta_escaped
),
)
]
)
return None

File diff suppressed because it is too large

View File

@@ -0,0 +1,744 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from the qwen3coder XML parser. All rights reserved.
# ruff: noqa: E501
import ast
import json
import uuid
from collections.abc import Sequence
from typing import Any
import regex as re
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
logger = init_logger(__name__)
class SeedOssToolParser(ToolParser):
TOOL_CALL_START = "<seed:tool_call>"
TOOL_CALL_END = "</seed:tool_call>"
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# --- streaming state ---
self._reset_streaming_state()
self.prev_tool_call_arr: list[dict] = []
self.tool_call_start_token: str = self.TOOL_CALL_START
self.tool_call_end_token: str = self.TOOL_CALL_END
# Sentinel tokens for streaming mode
self.tool_call_prefix: str = "<function="
self.function_end_token: str = "</function>"
self.parameter_prefix: str = "<parameter="
self.parameter_end_token: str = "</parameter>"
self.think_start_token: str = "<seed:think>"
self.think_end_token: str = "</seed:think>"
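# Roughly, the expected layout is (illustrative, hypothetical tool name):
#   <seed:think>...reasoning...</seed:think>
#   <seed:tool_call><function=get_weather>
#   <parameter=city>Paris</parameter>
#   </function></seed:tool_call>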
self.is_tool_call_started: bool = False
self.is_thinking_end: bool = False
self.failed_count: int = 0
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
self.think_end_token_id = self.vocab.get(self.think_end_token)
if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
raise RuntimeError(
"Seed_Oss XML parser: tokenizer did not include "
"<seed:tool_call> or its closing tag."
)
tool_start_re = re.escape(self.tool_call_start_token)
tool_end_re = re.escape(self.tool_call_end_token)
self.tool_call_complete_regex = re.compile(
rf"{tool_start_re}(.*?){tool_end_re}", re.DOTALL
)
self.tool_call_regex = re.compile(
rf"{tool_start_re}(.*?){tool_end_re}|{tool_start_re}(.*?)$", re.DOTALL
)
self.tool_call_function_regex = re.compile(
r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
)
self.tool_call_parameter_regex = re.compile(
r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL
)
logger.info(
"vLLM Seed-Oss XML tool parser loaded (%s).", self.__class__.__name__
)
def _generate_tool_call_id(self) -> str:
"""Generate a unique tool call ID."""
return f"call_{uuid.uuid4().hex[:24]}"
def _reset_streaming_state(self):
"""Reset all streaming state."""
self.current_tool_index = 0
self.is_tool_call_started = False
self.header_sent = False
self.current_tool_id = -1
self.current_function_name = None
self.current_param_name = None
self.current_param_value = ""
self.param_count = 0
self.in_param = False
self.in_function = False
self.accumulated_text = ""
self.json_started = False
self.json_closed = False
def _parse_xml_function_call(
self, function_call_str: str, tools: list[ChatCompletionToolsParam] | None
) -> ToolCall | None:
def get_arguments_config(func_name: str) -> dict:
if tools is None:
return {}
for config in tools:
if not hasattr(config, "type") or not (
hasattr(config, "function") and hasattr(config.function, "name")
):
continue
if config.type == "function" and config.function.name == func_name:
if not hasattr(config.function, "parameters"):
return {}
params = config.function.parameters
if isinstance(params, dict) and "properties" in params:
return params["properties"]
elif isinstance(params, dict):
return params
else:
return {}
logger.warning("Tool '%s' is not defined in the tools list.", func_name)
return {}
def convert_param_value(
param_value: str, param_name: str, param_config: dict, func_name: str
) -> Any:
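# e.g. with param_config={"force": {"type": "boolean"}} (hypothetical
# schema), "true" becomes True, "false" becomes False, and anything
# else logs a warning and yields False.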
# Handle null value for any type
if param_value.lower() == "null":
return None
if param_name not in param_config:
if param_config != {}:
logger.warning(
"Parsed parameter '%s' is not defined in "
"the tool parameters for tool '%s', "
"directly returning the string value.",
param_name,
func_name,
)
return param_value
if (
isinstance(param_config[param_name], dict)
and "type" in param_config[param_name]
):
param_type = str(param_config[param_name]["type"]).strip().lower()
else:
param_type = "string"
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
return param_value
elif (
param_type.startswith("int")
or param_type.startswith("uint")
or param_type.startswith("long")
or param_type.startswith("short")
or param_type.startswith("unsigned")
):
try:
param_value = int(param_value) # type: ignore
except (ValueError, TypeError):
logger.warning(
"Parsed value '%s' of parameter '%s' is not an integer in tool "
"'%s', degenerating to string.",
param_value,
param_name,
func_name,
)
return param_value
elif param_type.startswith("num") or param_type.startswith("float"):
try:
float_param_value = float(param_value)
param_value = (
float_param_value # type: ignore
if float_param_value - int(float_param_value) != 0
else int(float_param_value) # type: ignore
)
except (ValueError, TypeError):
logger.warning(
"Parsed value '%s' of parameter '%s' is not a float in tool "
"'%s', degenerating to string.",
param_value,
param_name,
func_name,
)
return param_value
elif param_type in ["boolean", "bool", "binary"]:
param_value = param_value.lower()
if param_value not in ["true", "false"]:
logger.warning(
"Parsed value '%s' of parameter '%s' is not a boolean "
"(`true` of `false`) in tool '%s', degenerating to false.",
param_value,
param_name,
func_name,
)
return param_value == "true"
else:
if param_type == "object" or param_type.startswith("dict"):
try:
param_value = json.loads(param_value)
return param_value
except (ValueError, TypeError, json.JSONDecodeError):
logger.warning(
"Parsed value '%s' of parameter '%s' is not a valid JSON "
"object in tool '%s', will try other methods to parse it.",
param_value,
param_name,
func_name,
)
try:
param_value = ast.literal_eval(param_value)
except (ValueError, SyntaxError):
logger.warning(
"Parsed value '%s' of parameter '%s' cannot be converted via "
"Python `ast.literal_eval()` in tool '%s', degenerating to string.",
param_value,
param_name,
func_name,
)
return param_value
# Extract function name
end_index = function_call_str.index(">")
function_name = function_call_str[:end_index]
param_config = get_arguments_config(function_name)
parameters = function_call_str[end_index + 1 :]
param_dict = {}
for match in self.tool_call_parameter_regex.findall(parameters):
match_text = match[0] if match[0] else match[1]
idx = match_text.index(">")
param_name = match_text[:idx]
param_value = str(match_text[idx + 1 :])
# Remove prefix and trailing \n
if param_value.startswith("\n"):
param_value = param_value[1:]
if param_value.endswith("\n"):
param_value = param_value[:-1]
param_dict[param_name] = convert_param_value(
param_value, param_name, param_config, function_name
)
return ToolCall(
type="function",
function=FunctionCall(
name=function_name, arguments=json.dumps(param_dict, ensure_ascii=False)
),
)
def _get_function_calls(self, model_output: str) -> list[str]:
# Find all tool calls
matched_ranges = self.tool_call_regex.findall(model_output)
raw_tool_calls = [
match[0] if match[0] else match[1] for match in matched_ranges
]
# Back-off strategy if no tool_call tags found
if len(raw_tool_calls) == 0:
raw_tool_calls = [model_output]
raw_function_calls = []
for tool_call in raw_tool_calls:
raw_function_calls.extend(self.tool_call_function_regex.findall(tool_call))
function_calls = [
match[0] if match[0] else match[1] for match in raw_function_calls
]
return function_calls
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
# Quick check to avoid unnecessary processing
if self.tool_call_prefix not in model_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
# Check if both think start and end tokens are present
if (
self.think_start_token in model_output
and self.think_end_token in model_output
):
# Find the position of think end token
think_end_index = model_output.find(self.think_end_token) + len(
self.think_end_token
)
# Extract content after think end token
result_content = model_output[think_end_index:]
thinking_content = model_output[:think_end_index]
else:
thinking_content = ""
result_content = model_output
try:
function_calls = self._get_function_calls(result_content)
if len(function_calls) == 0:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
tool_calls = [
self._parse_xml_function_call(function_call_str, request.tools)
for function_call_str in function_calls
]
# Populate prev_tool_call_arr for serving layer to set finish_reason
self.prev_tool_call_arr.clear() # Clear previous calls
for tool_call in tool_calls:
if tool_call:
self.prev_tool_call_arr.append(
{
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
}
)
# Extract content before tool calls
tool_call_start_index = result_content.find(self.tool_call_start_token)
tool_call_start_index = (
tool_call_start_index
if tool_call_start_index >= 0
else result_content.find(self.tool_call_prefix)
)
content = thinking_content + result_content[:tool_call_start_index]
return ExtractedToolCallInformation(
tools_called=(len(tool_calls) > 0),
tool_calls=tool_calls,
content=content if content else None,
)
except Exception:
logger.exception("Error in extracting tool call from response.")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
# If no delta text, return None unless
# it's an EOS token after tool calls
if not delta_text:
# Check if this is an EOS token after all tool calls are complete
# We check for tool calls in the text even if is_tool_call_started
# is False because it might have been reset after processing all tools
if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids:
# Count complete tool calls
complete_calls = len(
self.tool_call_complete_regex.findall(current_text)
)
# If we have completed tool calls and populated prev_tool_call_arr
if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
# Check if all tool calls are closed
open_calls = current_text.count(
self.tool_call_start_token
) - current_text.count(self.tool_call_end_token)
if open_calls == 0:
# Return empty delta message to allow finish_reason processing
return DeltaMessage(content="")
elif not self.is_tool_call_started and current_text:
# This is a regular content response that's now complete
return DeltaMessage(content="")
return None
# Check if this is the first call (reset state if needed)
if not previous_text:
self._reset_streaming_state()
# Update accumulated text
self.accumulated_text = current_text
# Check if we need to advance to next tool
if self.json_closed and not self.in_function:
# Check if this tool call has ended
tool_ends = current_text.count(self.tool_call_end_token)
if tool_ends > self.current_tool_index:
# This tool has ended, advance to next
self.current_tool_index += 1
self.header_sent = False
self.param_count = 0
self.json_started = False
self.json_closed = False
# Check if there are more tool calls
if self.current_tool_index >= current_text.count(
self.tool_call_start_token
):
# No more tool calls
self.is_tool_call_started = False
# Continue processing next tool
return None
# Check if end thinking
if not self.is_thinking_end and (
self.think_end_token_id in delta_token_ids
or self.think_end_token in delta_text
):
self.is_thinking_end = True
# If thinking hasn't ended yet, don't process any tool calls
if not self.is_thinking_end:
return DeltaMessage(content=delta_text)
# Handle normal content before tool calls
if not self.is_tool_call_started:
# Check if tool call is starting
if (
self.tool_call_start_token_id in delta_token_ids
or self.tool_call_start_token in delta_text
):
self.is_tool_call_started = True
# Return any content before the tool call
if self.tool_call_start_token in delta_text:
content_before = delta_text[
: delta_text.index(self.tool_call_start_token)
]
if content_before:
return DeltaMessage(content=content_before)
return None
else:
# Check if we're between tool calls - skip whitespace
if (
current_text.rstrip().endswith(self.tool_call_end_token)
and delta_text.strip() == ""
):
# We just ended a tool call, skip whitespace
return None
# Normal content, no tool call
return DeltaMessage(content=delta_text)
# Check if we're between tool calls (waiting for next one)
# Count tool calls we've seen vs processed
tool_starts_count = current_text.count(self.tool_call_start_token)
if self.current_tool_index >= tool_starts_count:
# We're past all tool calls, shouldn't be here
return None
# We're in a tool call, find the current tool call portion
# Need to find the correct tool call based on current_tool_index
# Only process tool calls after think_end_token
think_end_index = (
current_text.find(self.think_end_token) + len(self.think_end_token)
if self.think_end_token in current_text
else 0
)
tool_starts: list[int] = []
idx = think_end_index
while True:
idx = current_text.find(self.tool_call_start_token, idx)
if idx == -1:
break
tool_starts.append(idx)
idx += len(self.tool_call_start_token)
if self.current_tool_index >= len(tool_starts):
# No more tool calls to process yet
return None
tool_start_idx = tool_starts[self.current_tool_index]
# Find where this tool call ends (or current position if not ended yet)
tool_end_idx = current_text.find(self.tool_call_end_token, tool_start_idx)
if tool_end_idx == -1:
tool_text = current_text[tool_start_idx:]
else:
tool_text = current_text[
tool_start_idx : tool_end_idx + len(self.tool_call_end_token)
]
# Looking for function header
if not self.header_sent:
if self.tool_call_prefix in tool_text:
func_start = tool_text.find(self.tool_call_prefix) + len(
self.tool_call_prefix
)
func_end = tool_text.find(">", func_start)
if func_end != -1:
# Found complete function name
self.current_function_name = tool_text[func_start:func_end]
self.current_tool_id = self._generate_tool_call_id() # type: ignore
self.header_sent = True
self.in_function = True
# IMPORTANT: Add to prev_tool_call_arr immediately when we detect a tool call
# This ensures finish_reason="tool_calls" even if parsing isn't complete
already_added = any(
tool.get("name") == self.current_function_name
for tool in self.prev_tool_call_arr
)
if not already_added:
self.prev_tool_call_arr.append(
{
"name": self.current_function_name,
"arguments": "{}", # Placeholder, will be updated later
}
)
# Send header with function info
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
id=self.current_tool_id,
function=DeltaFunctionCall(
name=self.current_function_name, arguments=""
),
type="function",
)
]
)
return None
# We've sent header, now handle function body
if self.in_function:
# Send opening brace if not sent yet
if not self.json_started and self.parameter_prefix not in delta_text:
self.json_started = True
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="{"),
)
]
)
# Make sure json_started is set if we're processing parameters
if not self.json_started:
self.json_started = True
# Check for function end in accumulated text
if not self.json_closed and self.function_end_token in tool_text:
# Close JSON
self.json_closed = True
# Extract the complete tool call to update prev_tool_call_arr with final arguments
# Find the function content
func_start = tool_text.find(self.tool_call_prefix) + len(
self.tool_call_prefix
)
func_content_end = tool_text.find(self.function_end_token, func_start)
if func_content_end != -1:
func_content = tool_text[func_start:func_content_end]
# Parse to get the complete arguments
try:
parsed_tool = self._parse_xml_function_call(
func_content, request.tools if request else None
)
if parsed_tool:
# Update existing entry in prev_tool_call_arr with complete arguments
for i, tool in enumerate(self.prev_tool_call_arr):
if tool.get("name") == parsed_tool.function.name:
self.prev_tool_call_arr[i]["arguments"] = (
parsed_tool.function.arguments
)
break
except Exception:
logger.warning(
"Failed to parse tool arguments during streaming.",
exc_info=True,
)
result = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="}"),
)
]
)
# Reset state for next tool
self.in_function = False
self.json_closed = True
return result
# Look for parameters
# Count how many complete parameters we have processed
complete_params = tool_text.count(self.parameter_end_token)
# Check if we should start a new parameter
if not self.in_param and self.param_count < complete_params:
# Find the unprocessed parameter
# Count parameter starts
param_starts = []
idx = 0
while True:
idx = tool_text.find(self.parameter_prefix, idx)
if idx == -1:
break
param_starts.append(idx)
idx += len(self.parameter_prefix)
if len(param_starts) > self.param_count:
# Process the next parameter
param_idx = param_starts[self.param_count]
param_start = param_idx + len(self.parameter_prefix)
remaining = tool_text[param_start:]
if ">" in remaining:
# We have the complete parameter name
name_end = remaining.find(">")
self.current_param_name = remaining[:name_end]
# Find the parameter value
value_start = param_start + name_end + 1
value_text = tool_text[value_start:]
if value_text.startswith("\n"):
value_text = value_text[1:]
# Find where this parameter ends
param_end_idx = value_text.find(self.parameter_end_token)
if param_end_idx != -1:
# Complete parameter found
param_value = value_text[:param_end_idx]
if param_value.endswith("\n"):
param_value = param_value[:-1]
# Build complete JSON fragment for this parameter
if self.param_count == 0:
json_fragment = (
'"'
+ self.current_param_name
+ '": "'
+ json.dumps(param_value)[1:-1]
+ '"'
)
else:
json_fragment = (
', "'
+ self.current_param_name
+ '": "'
+ json.dumps(param_value)[1:-1]
+ '"'
)
self.param_count += 1
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments=json_fragment
),
)
]
)
# Continue parameter value
if self.in_param:
if self.parameter_end_token in delta_text:
# End of parameter
end_idx = delta_text.find(self.parameter_end_token)
value_chunk = delta_text[:end_idx]
# Skip past > if at start
if not self.current_param_value and ">" in value_chunk:
gt_idx = value_chunk.find(">")
value_chunk = value_chunk[gt_idx + 1 :]
if not self.current_param_value and value_chunk.startswith("\n"):
value_chunk = value_chunk[1:]
# Calculate incremental JSON
full_value = self.current_param_value + value_chunk
prev_escaped = (
json.dumps(self.current_param_value)[1:-1]
if self.current_param_value
else ""
)
full_escaped = json.dumps(full_value)[1:-1]
delta_escaped = full_escaped[len(prev_escaped) :]
self.in_param = False
self.current_param_value = ""
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments=delta_escaped + '"'
),
)
]
)
else:
# Continue accumulating value
value_chunk = delta_text
# Handle first chunk after param name
if not self.current_param_value and ">" in value_chunk:
gt_idx = value_chunk.find(">")
value_chunk = value_chunk[gt_idx + 1 :]
if not self.current_param_value and value_chunk.startswith("\n"):
value_chunk = value_chunk[1:]
if value_chunk:
# Stream the escaped delta
prev_escaped = (
json.dumps(self.current_param_value)[1:-1]
if self.current_param_value
else ""
)
self.current_param_value += value_chunk
full_escaped = json.dumps(self.current_param_value)[1:-1]
delta_escaped = full_escaped[len(prev_escaped) :]
if delta_escaped:
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments=delta_escaped
),
)
]
)
return None

View File

@@ -0,0 +1,303 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import json
from collections.abc import Sequence
from typing import Any
import regex as re
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.utils import random_uuid
logger = init_logger(__name__)
class Step3ToolParser(ToolParser):
"""
Tool parser for a model that uses a specific XML-like format for tool calls.
This version uses a robust, stateful, cursor-based streaming parser and
consolidates tool arguments into a single message.
"""
TOOL_CALLS_BEGIN = "<tool_calls_begin>"
TOOL_CALLS_END = "<tool_calls_end>"
TOOL_CALL_BEGIN = "<tool_call_begin>"
TOOL_CALL_END = "<tool_call_end>"
TOOL_SEP = "<tool_sep>"
SPECIAL_TOKENS = [TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END]
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.position = 0
# Explicit state flags for robust streaming
self.tool_block_started = False
self.tool_block_finished = False
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
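# Keep special tokens in the detokenized output so the XML sentinels
# (e.g. <tool_calls_begin>) remain visible to this parser.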
request.skip_special_tokens = False
return request
@staticmethod
def _parse_steptml_invoke(
action_text: str,
) -> tuple[str | None, dict[str, str] | None]:
func_name_match = re.search(r'<steptml:invoke name="([^"]+)">', action_text)
if not func_name_match:
return None, None
func_name = func_name_match.group(1)
params: dict[str, str] = {}
param_matches = re.findall(
r'<steptml:parameter name="([^"]+)">([^<]*)</steptml:parameter>',
action_text,
)
for name, value in param_matches:
params[name] = value.strip()
return func_name, params
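# e.g. parsing '<steptml:invoke name="get_weather">' followed by
# '<steptml:parameter name="city">Paris</steptml:parameter>' returns
# ("get_weather", {"city": "Paris"}) (hypothetical tool name).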
def _cast_arguments(
self,
func_name: str,
params: dict[str, Any],
request: ChatCompletionRequest,
) -> dict[str, Any]:
for tool in request.tools or []:
if tool.function.name == func_name:
schema = tool.function.parameters or {}
properties = schema.get("properties", {})
for key, value in params.items():
if not isinstance(value, str):
continue
prop = properties.get(key, {})
typ = prop.get("type")
if typ == "string":
params[key] = value.strip()
elif typ == "integer":
with contextlib.suppress(ValueError):
params[key] = int(value)
elif typ == "number":
with contextlib.suppress(ValueError):
params[key] = float(value)
elif typ == "boolean":
lower_val = value.lower()
params[key] = (
lower_val == "true"
if lower_val in ("true", "false")
else value
)
elif typ == "null":
params[key] = None if value.lower() == "null" else value
break
return params
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
# The main loop processes the stream from the last known position.
while True:
if self.position >= len(current_text):
return None # We've processed the entire stream.
unprocessed_text = current_text[self.position :]
# STATE: After all tools are done, all subsequent text is content.
if self.tool_block_finished:
self.position = len(current_text)
return DeltaMessage(content=unprocessed_text)
# STATE: Before the tool block has started.
if not self.tool_block_started:
if unprocessed_text.startswith(self.TOOL_CALLS_BEGIN):
self.position += len(self.TOOL_CALLS_BEGIN)
self.tool_block_started = True
continue # Token consumed, re-loop.
start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN)
if start_pos == -1:
if (
self.TOOL_CALLS_BEGIN.startswith(unprocessed_text.strip())
and unprocessed_text
):
return None # It's a prefix, wait.
self.position = len(current_text)
return DeltaMessage(content=unprocessed_text)
else:
content = unprocessed_text[:start_pos]
self.position += len(content)
return DeltaMessage(content=content)
# STATE: Inside the main tool block.
offset = len(unprocessed_text) - len(unprocessed_text.lstrip())
unprocessed_text = unprocessed_text.lstrip()
self.position += offset
if unprocessed_text.startswith(self.TOOL_CALLS_END):
self.position += len(self.TOOL_CALLS_END)
self.tool_block_finished = True
self.current_tool_id = -1
continue
# Check if we are between tool calls.
tool_finished = self.current_tool_id != -1 and self.prev_tool_call_arr[
self.current_tool_id
].get("finished")
if self.current_tool_id == -1 or tool_finished:
if unprocessed_text.startswith(self.TOOL_CALL_BEGIN):
self.position += len(self.TOOL_CALL_BEGIN)
if self.current_tool_id == -1:
self.current_tool_id = 0
else:
self.current_tool_id += 1
self.current_tool_name_sent = False
while len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
self.prev_tool_call_arr[self.current_tool_id]["finished"] = False
continue
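                # The text so far may still be a prefix of <tool_call_begin>;
                # wait for more tokens before deciding.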
if self.TOOL_CALL_BEGIN.startswith(unprocessed_text):
return None
# STATE: Parsing an active tool call.
if self.current_tool_id != -1 and not self.prev_tool_call_arr[
self.current_tool_id
].get("finished", False):
end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END)
if end_tool_pos == -1:
tool_body = unprocessed_text
else:
tool_body = unprocessed_text[:end_tool_pos]
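                # If the partial body could still be a prefix of
                # <tool_call_end>, wait for more tokens.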
if end_tool_pos == -1 and self.TOOL_CALL_END.startswith(tool_body):
return None
function_name, arguments = self._parse_steptml_invoke(tool_body)
if not function_name:
return None
tool_call_arr = {"name": function_name, "parameters": arguments or {}}
# Send the function name as soon as it's parsed.
if not self.current_tool_name_sent:
self.current_tool_name_sent = True
self.prev_tool_call_arr[self.current_tool_id].update(tool_call_arr)
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=f"chatcmpl-tool-{random_uuid()}",
function=DeltaFunctionCall(name=function_name),
)
]
)
# Update our internal state with the latest parsed arguments.
self.prev_tool_call_arr[self.current_tool_id].update( # noqa: E501
tool_call_arr
)
# Only send arguments when the tool call is complete.
if end_tool_pos != -1:
self.position += end_tool_pos + len(self.TOOL_CALL_END)
self.prev_tool_call_arr[self.current_tool_id]["finished"] = True
final_args = self._cast_arguments(
function_name,
tool_call_arr.get("parameters", {}), # type: ignore
request,
)
if final_args:
final_args_json = json.dumps(final_args, ensure_ascii=False)
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=final_args_json
),
)
]
)
# If tool is not finished, return None to wait for more tokens.
return None
return None
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
if self.TOOL_CALLS_BEGIN not in model_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
pre_text, rest = model_output.split(self.TOOL_CALLS_BEGIN, 1)
if self.TOOL_CALLS_END not in rest:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
tool_block, post_text = rest.split(self.TOOL_CALLS_END, 1)
content = (pre_text + post_text).strip()
tool_calls: list[ToolCall] = []
call_parts = tool_block.split(self.TOOL_CALL_BEGIN)
for part in call_parts:
if not part or self.TOOL_CALL_END not in part:
continue
call_content = part.split(self.TOOL_CALL_END, 1)[0]
if self.TOOL_SEP not in call_content:
continue
type_part, invoke_part = call_content.split(self.TOOL_SEP, 1)
if type_part.strip() != "function":
continue
function_name, params_dict = self._parse_steptml_invoke(invoke_part)
if function_name and params_dict is not None:
params_dict = self._cast_arguments(function_name, params_dict, request)
params_str = json.dumps(params_dict, ensure_ascii=False)
tool_calls.append(
ToolCall(
function=FunctionCall(name=function_name, arguments=params_str)
)
)
if tool_calls:
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if content else None,
)
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)

229
vllm/tool_parsers/utils.py Normal file
View File

@@ -0,0 +1,229 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from json import JSONDecodeError, JSONDecoder
from typing import Any
import partial_json_parser
from openai.types.responses import (
FunctionTool,
ToolChoiceFunction,
)
from openai.types.responses.tool import Tool
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam,
ChatCompletionToolsParam,
)
def find_common_prefix(s1: str, s2: str) -> str:
"""
Finds a common prefix that is shared between two strings, if there is one.
Order of arguments is NOT important.
This function is provided as a UTILITY for extracting information from JSON
generated by partial_json_parser, to help in ensuring that the right tokens
are returned in streaming, so that close-quotes, close-brackets and
close-braces are not returned prematurely.
e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
'{"fruit": "ap'
"""
prefix = ""
min_length = min(len(s1), len(s2))
for i in range(0, min_length):
if s1[i] == s2[i]:
prefix += s1[i]
else:
break
return prefix
def find_common_suffix(s1: str, s2: str) -> str:
"""
Finds a common suffix shared between two strings, if there is one. Order of
arguments is NOT important.
Stops when the suffix ends OR it hits an alphanumeric character
e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
"""
suffix = ""
min_length = min(len(s1), len(s2))
for i in range(1, min_length + 1):
if s1[-i] == s2[-i] and not s1[-i].isalnum():
suffix = s1[-i] + suffix
else:
break
return suffix
def extract_intermediate_diff(curr: str, old: str) -> str:
"""
Given two strings, extract the difference in the middle between two strings
that are known to have a common prefix and/or suffix.
This function is provided as a UTILITY for extracting information from JSON
generated by partial_json_parser, to help in ensuring that the right tokens
are returned in streaming, so that close-quotes, close-brackets and
close-braces are not returned prematurely. The order of arguments IS
important - the new version of the partially-parsed JSON must be the first
    argument, and the second argument must be from the previous generation.
    What it returns is the tokens that should be streamed to the client.
e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
-> 'ple'
"""
suffix = find_common_suffix(curr, old)
old = old[::-1].replace(suffix[::-1], "", 1)[::-1]
prefix = find_common_prefix(curr, old)
diff = curr
if len(suffix):
diff = diff[::-1].replace(suffix[::-1], "", 1)[::-1]
if len(prefix):
# replace the prefix only once in case it's mirrored
diff = diff.replace(prefix, "", 1)
return diff
def find_all_indices(string: str, substring: str) -> list[int]:
"""
Find all (starting) indices of a substring in a given string. Useful for
    tool call extraction.
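    e.g. find_all_indices('hello', 'l') -> [2, 3]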
"""
indices = []
index = -1
while True:
index = string.find(substring, index + 1)
if index == -1:
break
indices.append(index)
return indices
# partial_json_parser doesn't support extra data and
# JSONDecoder.raw_decode doesn't support partial JSON
def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
try:
return (partial_json_parser.loads(input_str, flags), len(input_str))
except JSONDecodeError as e:
if "Extra data" in e.msg:
dec = JSONDecoder()
return dec.raw_decode(input_str)
raise
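# Illustrative behaviour (assuming partial_json_parser completes the value):
#   partial_json_loads('{"a": 1', Allow.ALL)       -> ({"a": 1}, 7)
#   partial_json_loads('{"a": 1} tail', Allow.ALL) -> ({"a": 1}, 8)
# The second case falls back to JSONDecoder.raw_decode, whose returned index
# marks the end of the first complete JSON value.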
def is_complete_json(input_str: str) -> bool:
try:
json.loads(input_str)
return True
except JSONDecodeError:
return False
def consume_space(i: int, s: str) -> int:
while i < len(s) and s[i].isspace():
i += 1
return i
def _extract_tool_info(
tool: Tool | ChatCompletionToolsParam,
) -> tuple[str, dict[str, Any] | None]:
if isinstance(tool, FunctionTool):
return tool.name, tool.parameters
elif isinstance(tool, ChatCompletionToolsParam):
return tool.function.name, tool.function.parameters
else:
raise TypeError(f"Unsupported tool type: {type(tool)}")
def _get_tool_schema_from_tool(tool: Tool | ChatCompletionToolsParam) -> dict:
name, params = _extract_tool_info(tool)
params = params if params else {"type": "object", "properties": {}}
return {
"properties": {
"name": {"type": "string", "enum": [name]},
"parameters": params,
},
"required": ["name", "parameters"],
}
def _get_tool_schema_defs(
tools: list[Tool | ChatCompletionToolsParam],
) -> dict:
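    # Merge "$defs" shared across tools into one namespace; note that
    # params.pop mutates the tool's schema in place, and conflicting
    # definitions with the same name are rejected.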
all_defs: dict[str, dict[str, Any]] = {}
for tool in tools:
_, params = _extract_tool_info(tool)
if params is None:
continue
defs = params.pop("$defs", {})
for def_name, def_schema in defs.items():
if def_name in all_defs and all_defs[def_name] != def_schema:
raise ValueError(
f"Tool definition '{def_name}' has multiple schemas, "
"which is not supported."
)
all_defs[def_name] = def_schema
return all_defs
def _get_json_schema_from_tools(
tools: list[Tool | ChatCompletionToolsParam],
) -> dict:
json_schema = {
"type": "array",
"minItems": 1,
"items": {
"type": "object",
"anyOf": [_get_tool_schema_from_tool(tool) for tool in tools],
},
}
json_schema_defs = _get_tool_schema_defs(tools)
if json_schema_defs:
json_schema["$defs"] = json_schema_defs
return json_schema
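# For a single tool named "get_weather", the schema built above is, roughly:
# {"type": "array", "minItems": 1,
#  "items": {"type": "object",
#            "anyOf": [{"properties": {"name": {"type": "string",
#                                               "enum": ["get_weather"]},
#                                      "parameters": {...}},
#                       "required": ["name", "parameters"]}]}}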
def get_json_schema_from_tools(
tool_choice: str | ToolChoiceFunction | ChatCompletionNamedToolChoiceParam,
tools: list[FunctionTool | ChatCompletionToolsParam] | None,
) -> str | dict | None:
# tool_choice: "none"
if tool_choice in ("none", None) or tools is None:
return None
# tool_choice: Forced Function (Responses)
if (not isinstance(tool_choice, str)) and isinstance(
tool_choice, ToolChoiceFunction
):
tool_name = tool_choice.name
tool_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)}
if tool_name not in tool_map:
raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.")
return tool_map[tool_name].parameters
# tool_choice: Forced Function (ChatCompletion)
if (not isinstance(tool_choice, str)) and isinstance(
tool_choice, ChatCompletionNamedToolChoiceParam
):
tool_name = tool_choice.function.name
tool_map = {
tool.function.name: tool
for tool in tools
if isinstance(tool, ChatCompletionToolsParam)
}
if tool_name not in tool_map:
raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.")
return tool_map[tool_name].function.parameters
# tool_choice: "required"
if tool_choice == "required":
return _get_json_schema_from_tools(tools)
# tool_choice: "auto"
return None

556
vllm/tool_parsers/xlam_tool_parser.py Normal file
View File

@@ -0,0 +1,556 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa
import json
from collections.abc import Sequence
from typing import Any, Optional, Union
import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
logger = init_logger(__name__)
class xLAMToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# Initialize state for streaming mode
self.prev_tool_calls: list[dict] = []
self.current_tool_id = -1
self.current_tool_name_sent = False
self.streamed_args: list[str] = [] # Track arguments sent for each tool
# For backward compatibility with tests
self.current_tools_sent: list[bool] = []
# For backward compatibility with serving code
self.prev_tool_call_arr = []
# Regex patterns for preprocessing
self.json_code_block_patterns = [
r"```(?:json)?\s*([\s\S]*?)```",
r"\[TOOL_CALLS\]([\s\S]*?)(?=\n|$)",
r"<tool_call>([\s\S]*?)</tool_call>",
]
self.thinking_tag_pattern = r"</think>([\s\S]*)"
# Define streaming state type to be initialized later
self.streaming_state: dict[str, Any] = {
"current_tool_index": -1,
"tool_ids": [],
"sent_tools": [],
}
def preprocess_model_output(
self, model_output: str
) -> tuple[Optional[str], Optional[str]]:
"""
Preprocess the model output to extract content and potential tool calls.
Returns:
Tuple of (content, potential_tool_calls_json)
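        Example (traced through the code below; illustrative input):
            'Done.</think>[{"name": "f", "arguments": {}}]' ->
            ('Done.</think>', '[{"name": "f", "arguments": {}}]')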
"""
# Check for thinking tag
thinking_match = re.search(self.thinking_tag_pattern, model_output)
if thinking_match:
content = model_output[: thinking_match.start() + len("</think>")].strip()
thinking_content = thinking_match.group(1).strip()
# Try to parse the thinking content as JSON
try:
json.loads(thinking_content)
return content, thinking_content
except json.JSONDecodeError:
# If can't parse as JSON, look for JSON code blocks
for json_pattern in self.json_code_block_patterns:
json_matches = re.findall(json_pattern, thinking_content)
if json_matches:
for json_str in json_matches:
try:
json.loads(json_str)
return content, json_str
except json.JSONDecodeError:
continue
# Check for JSON code blocks in the entire output
for json_pattern in self.json_code_block_patterns:
json_matches = re.findall(json_pattern, model_output)
if json_matches:
for json_str in json_matches:
try:
json.loads(json_str)
# Extract content by removing the JSON code block
content = re.sub(json_pattern, "", model_output).strip()
return content, json_str
except json.JSONDecodeError:
continue
# If the entire output is a valid JSON array or looks like one, treat it as tool calls
if model_output.strip().startswith("["):
try:
json.loads(model_output)
return None, model_output
except json.JSONDecodeError:
# Even if it's not valid JSON yet, it might be a tool call in progress
if (
"{" in model_output
and "name" in model_output
and "arguments" in model_output
):
return None, model_output
# If no tool calls found, return the original output as content
return model_output, None
def extract_tool_calls(
self, model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation:
"""
Extract tool calls from a complete model output.
"""
try:
# Preprocess the model output
content, potential_tool_calls = self.preprocess_model_output(model_output)
if not potential_tool_calls:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=content
)
# Parse the potential tool calls as JSON
tool_calls_data = json.loads(potential_tool_calls)
# Ensure it's an array
if not isinstance(tool_calls_data, list):
logger.debug("Tool calls data is not an array")
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
content=content or model_output,
)
tool_calls: list[ToolCall] = []
for idx, call in enumerate(tool_calls_data):
if (
not isinstance(call, dict)
or "name" not in call
or "arguments" not in call
):
logger.debug("Invalid tool call format at index %d", idx)
continue
tool_call = ToolCall(
id=f"call_{idx}_{random_uuid()}",
type="function",
function=FunctionCall(
name=call["name"],
arguments=(
json.dumps(call["arguments"])
if isinstance(call["arguments"], dict)
else call["arguments"]
),
),
)
tool_calls.append(tool_call)
return ExtractedToolCallInformation(
tools_called=len(tool_calls) > 0,
tool_calls=tool_calls,
content=content,
)
except Exception as e:
logger.exception("Error extracting tool calls: %s", str(e))
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
"""
Extract tool calls for streaming mode.
"""
# First, check for a definitive start of a tool call block.
# This prevents premature parsing of incomplete output.
stripped_text = current_text.strip()
preprocessed_content, preprocessed_tool_calls = self.preprocess_model_output(
current_text
)
# For JSON code blocks, we need to detect them earlier, even if incomplete
has_potential_json_block = (
"```json" in current_text
or "```\n[" in current_text
or "[TOOL_CALLS]" in current_text
or "<tool_call>" in current_text
)
is_tool_call_block = (
stripped_text.startswith("[")
or stripped_text.startswith("<tool_call>")
or stripped_text.startswith("[TOOL_CALLS]")
or
# Check if we have thinking tags with JSON-like content following
("</think>[" in current_text)
or
# Check if the text contains a JSON array after preprocessing
preprocessed_tool_calls is not None
or
# For JSON code blocks, detect early if we see enough structure
(
has_potential_json_block
and '"name"' in current_text
and '"arguments"' in current_text
)
)
if not is_tool_call_block:
return DeltaMessage(content=delta_text)
try:
# Initialize streaming state if not exists
if not hasattr(self, "streaming_state"):
self.streaming_state = {
"current_tool_index": -1,
"tool_ids": [],
"sent_tools": [], # Track complete state of each tool
}
# Try parsing as JSON to check for complete tool calls
try:
# Use preprocessed tool calls if available
tool_calls_text = (
preprocessed_tool_calls if preprocessed_tool_calls else current_text
)
parsed_tools = json.loads(tool_calls_text)
if isinstance(parsed_tools, list):
# Update our tool array for next time
self.prev_tool_call_arr = parsed_tools
except json.JSONDecodeError:
# Not complete JSON yet, use regex for partial parsing
pass
# Check for test-specific state setup (current_tools_sent)
# This handles the case where tests manually set current_tools_sent
if (
hasattr(self, "current_tools_sent") # type: ignore
and len(self.current_tools_sent) > 0
):
# If current_tools_sent is set to [False], it means the test wants us to send the name
if (
len(self.current_tools_sent) == 1
and self.current_tools_sent[0] is False
):
# Extract the function name using regex
name_pattern = r'"name"\s*:\s*"([^"]+)"'
name_match = re.search(name_pattern, current_text)
if name_match:
function_name = name_match.group(1)
# The test expects us to send just the name first
tool_id = make_tool_call_id()
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=0,
type="function",
id=tool_id,
function=DeltaFunctionCall(
name=function_name
).model_dump(exclude_none=True), # type: ignore
)
]
)
# Update state to reflect that we've sent the name
self.current_tools_sent = [True]
self.current_tool_id = 0
self.streaming_state["current_tool_index"] = 0
if len(self.streaming_state["sent_tools"]) == 0:
self.streaming_state["sent_tools"].append(
{
"sent_name": True,
"sent_arguments_prefix": False,
"sent_arguments": "",
}
)
else:
self.streaming_state["sent_tools"][0]["sent_name"] = True
self.current_tool_name_sent = True
return delta
# Use regex to identify tool calls in the output
# Use preprocessed tool calls text for better parsing, but also try to extract from incomplete JSON blocks
search_text = (
preprocessed_tool_calls if preprocessed_tool_calls else current_text
)
# For JSON code blocks that aren't complete yet, try to extract the JSON content
if not preprocessed_tool_calls and has_potential_json_block:
# Try to extract the JSON array from within the code block
json_match = re.search(
r"```(?:json)?\s*([\s\S]*?)(?:```|$)", current_text
)
if json_match:
potential_json = json_match.group(1).strip()
# Use this as search text even if it's incomplete
if potential_json.startswith("[") and (
'"name"' in potential_json and '"arguments"' in potential_json
):
search_text = potential_json
# Try to find complete tool names first
name_pattern = r'"name"\s*:\s*"([^"]+)"'
name_matches = list(re.finditer(name_pattern, search_text))
tool_count = len(name_matches)
# If no complete tool names found, check for partial tool names
if tool_count == 0:
# Check if we're in the middle of parsing a tool name
partial_name_pattern = r'"name"\s*:\s*"([^"]*)'
partial_matches = list(re.finditer(partial_name_pattern, search_text))
if partial_matches:
# We have a partial tool name - not ready to emit yet
return None
else:
# No tools found at all
return None
# Ensure our state arrays are large enough
while len(self.streaming_state["sent_tools"]) < tool_count:
self.streaming_state["sent_tools"].append(
{
"sent_name": False,
"sent_arguments_prefix": False,
"sent_arguments": "",
}
)
while len(self.streaming_state["tool_ids"]) < tool_count:
self.streaming_state["tool_ids"].append(None)
# Determine if we need to move to a new tool
current_idx = self.streaming_state["current_tool_index"]
# If we haven't processed any tool yet or current tool is complete, move to next
if current_idx == -1 or current_idx < tool_count - 1:
next_idx = current_idx + 1
# If tool at next_idx has not been sent yet
if (
next_idx < tool_count
and not self.streaming_state["sent_tools"][next_idx]["sent_name"]
):
# Update indexes
self.streaming_state["current_tool_index"] = next_idx
self.current_tool_id = next_idx # For backward compatibility
current_idx = next_idx
# Extract the tool name
tool_name = name_matches[current_idx].group(1)
# Generate ID and send tool name
tool_id = f"call_{current_idx}_{random_uuid()}"
self.streaming_state["tool_ids"][current_idx] = tool_id
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=current_idx,
type="function",
id=tool_id,
function=DeltaFunctionCall(name=tool_name).model_dump(
exclude_none=True
), # type: ignore
)
]
)
self.streaming_state["sent_tools"][current_idx]["sent_name"] = True
self.current_tool_name_sent = True # For backward compatibility
# Keep track of streamed args for backward compatibility
while len(self.streamed_args) <= current_idx:
self.streamed_args.append("")
return delta
# Process arguments for the current tool
if current_idx >= 0 and current_idx < tool_count:
# Support both regular and empty argument objects
# First, check for the empty arguments case: "arguments": {}
empty_args_pattern = (
r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}'
)
empty_args_match = re.search(empty_args_pattern, search_text)
# Check if this tool has empty arguments
if empty_args_match and empty_args_match.start() > 0:
# Find which tool this empty arguments belongs to
empty_args_tool_idx = 0
for i in range(tool_count):
if i == current_idx:
# If this is our current tool and it has empty arguments
if not self.streaming_state["sent_tools"][current_idx][
"sent_arguments_prefix"
]:
# Send empty object
self.streaming_state["sent_tools"][current_idx][
"sent_arguments_prefix"
] = True
self.streaming_state["sent_tools"][current_idx][
"sent_arguments"
] = "{}"
# Update streamed_args for backward compatibility
while len(self.streamed_args) <= current_idx:
self.streamed_args.append("")
self.streamed_args[current_idx] += "{}"
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=current_idx,
function=DeltaFunctionCall(
arguments="{}"
).model_dump(exclude_none=True), # type: ignore
)
]
)
# Move to next tool if available
if current_idx < tool_count - 1:
self.streaming_state["current_tool_index"] += 1
self.current_tool_id = self.streaming_state[
"current_tool_index"
]
return delta
# Extract arguments for current tool using regex for non-empty arguments
args_pattern = r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
args_matches = list(re.finditer(args_pattern, search_text))
if current_idx < len(args_matches):
args_text = args_matches[current_idx].group(1)
# Handle transition between tools
is_last_tool = current_idx == tool_count - 1
# For multiple tools, extract only the arguments for the current tool
if tool_count > 1:
# Parse the entire JSON structure to properly extract arguments for each tool
try:
parsed_tools = json.loads(search_text)
if isinstance(parsed_tools, list) and current_idx < len(
parsed_tools
):
current_tool = parsed_tools[current_idx]
if isinstance(current_tool.get("arguments"), dict):
args_text = json.dumps(current_tool["arguments"])
else:
args_text = str(current_tool.get("arguments", "{}"))
except (json.JSONDecodeError, KeyError, IndexError):
# Fallback to regex-based extraction
pass
# If arguments haven't been sent yet
sent_args = self.streaming_state["sent_tools"][current_idx][
"sent_arguments"
]
# If we haven't sent the opening bracket yet
if not self.streaming_state["sent_tools"][current_idx][
"sent_arguments_prefix"
] and args_text.startswith("{"):
self.streaming_state["sent_tools"][current_idx][
"sent_arguments_prefix"
] = True
self.streaming_state["sent_tools"][current_idx][
"sent_arguments"
] = "{"
# Update streamed_args for backward compatibility
while len(self.streamed_args) <= current_idx:
self.streamed_args.append("")
self.streamed_args[current_idx] += "{"
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=current_idx,
function=DeltaFunctionCall(
arguments="{"
).model_dump(exclude_none=True), # type: ignore
)
]
)
return delta
# If we need to send more arguments
if args_text.startswith(sent_args):
# Calculate what part of arguments we need to send
args_diff = args_text[len(sent_args) :]
if args_diff:
# Update our state
self.streaming_state["sent_tools"][current_idx][
"sent_arguments"
] = args_text
# Update streamed_args for backward compatibility
while len(self.streamed_args) <= current_idx:
self.streamed_args.append("")
self.streamed_args[current_idx] += args_diff
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=current_idx,
function=DeltaFunctionCall(
arguments=args_diff
).model_dump(exclude_none=True), # type: ignore
)
]
)
return delta
# If the tool's arguments are complete, check if we need to move to the next tool
if args_text.endswith("}") and args_text == sent_args:
# This tool is complete, move to the next one in the next iteration
if current_idx < tool_count - 1:
self.streaming_state["current_tool_index"] += 1
self.current_tool_id = self.streaming_state[
"current_tool_index"
] # For compatibility
# If we got here, we couldn't determine what to stream next
return None
except Exception as e:
logger.exception(f"Error in streaming tool calls: {e}")
# If we encounter an error, just return the delta text as regular content
return DeltaMessage(content=delta_text)