Sync from v0.13
This commit is contained in:
150
vllm/tool_parsers/__init__.py
Normal file
150
vllm/tool_parsers/__init__.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
ToolParserManager,
|
||||
)
|
||||
|
||||
__all__ = ["ToolParser", "ToolParserManager"]
|
||||
|
||||
|
||||
"""
|
||||
Register a lazy module mapping.
|
||||
|
||||
Example:
|
||||
ToolParserManager.register_lazy_module(
|
||||
name="kimi_k2",
|
||||
module_path="vllm.tool_parsers.kimi_k2_parser",
|
||||
class_name="KimiK2ToolParser",
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
# Mapping of public parser name -> (module file stem, class name) for every
# built-in tool parser. Each entry is registered lazily by
# register_lazy_tool_parsers() below, so a parser module is only imported the
# first time that parser is actually requested.
# Note: "llama3_json" and "llama4_json" intentionally map to the same
# module/class.
_TOOL_PARSERS_TO_REGISTER = {
    "deepseek_v3": (  # name
        "deepseekv3_tool_parser",  # filename
        "DeepSeekV3ToolParser",  # class_name
    ),
    "deepseek_v31": (
        "deepseekv31_tool_parser",
        "DeepSeekV31ToolParser",
    ),
    "deepseek_v32": (
        "deepseekv32_tool_parser",
        "DeepSeekV32ToolParser",
    ),
    "ernie45": (
        "ernie45_tool_parser",
        "Ernie45ToolParser",
    ),
    "glm45": (
        "glm4_moe_tool_parser",
        "Glm4MoeModelToolParser",
    ),
    "granite-20b-fc": (
        "granite_20b_fc_tool_parser",
        "Granite20bFCToolParser",
    ),
    "granite": (
        "granite_tool_parser",
        "GraniteToolParser",
    ),
    "hermes": (
        "hermes_tool_parser",
        "Hermes2ProToolParser",
    ),
    "hunyuan_a13b": (
        "hunyuan_a13b_tool_parser",
        "HunyuanA13BToolParser",
    ),
    "internlm": (
        "internlm2_tool_parser",
        "Internlm2ToolParser",
    ),
    "jamba": (
        "jamba_tool_parser",
        "JambaToolParser",
    ),
    "kimi_k2": (
        "kimi_k2_tool_parser",
        "KimiK2ToolParser",
    ),
    "llama3_json": (
        "llama_tool_parser",
        "Llama3JsonToolParser",
    ),
    "llama4_json": (
        "llama_tool_parser",
        "Llama3JsonToolParser",
    ),
    "llama4_pythonic": (
        "llama4_pythonic_tool_parser",
        "Llama4PythonicToolParser",
    ),
    "longcat": (
        "longcat_tool_parser",
        "LongcatFlashToolParser",
    ),
    "minimax_m2": (
        "minimax_m2_tool_parser",
        "MinimaxM2ToolParser",
    ),
    "minimax": (
        "minimax_tool_parser",
        "MinimaxToolParser",
    ),
    "mistral": (
        "mistral_tool_parser",
        "MistralToolParser",
    ),
    "olmo3": (
        "olmo3_tool_parser",
        "Olmo3PythonicToolParser",
    ),
    "openai": (
        "openai_tool_parser",
        "OpenAIToolParser",
    ),
    "phi4_mini_json": (
        "phi4mini_tool_parser",
        "Phi4MiniJsonToolParser",
    ),
    "pythonic": (
        "pythonic_tool_parser",
        "PythonicToolParser",
    ),
    "qwen3_coder": (
        "qwen3coder_tool_parser",
        "Qwen3CoderToolParser",
    ),
    "qwen3_xml": (
        "qwen3xml_tool_parser",
        "Qwen3XMLToolParser",
    ),
    "seed_oss": (
        "seed_oss_tool_parser",
        "SeedOssToolParser",
    ),
    "step3": (
        "step3_tool_parser",
        "Step3ToolParser",
    ),
    "xlam": (
        "xlam_tool_parser",
        "xLAMToolParser",
    ),
    "gigachat3": (
        "gigachat3_tool_parser",
        "GigaChat3ToolParser",
    ),
}
|
||||
|
||||
|
||||
def register_lazy_tool_parsers():
    """Register every built-in tool parser lazily with ToolParserManager.

    Only the name -> (module path, class name) mapping is recorded here;
    the parser modules themselves are imported on first use.
    """
    for parser_name, (file_stem, cls_name) in _TOOL_PARSERS_TO_REGISTER.items():
        ToolParserManager.register_lazy_module(
            parser_name, f"vllm.tool_parsers.{file_stem}", cls_name
        )


register_lazy_tool_parsers()
|
||||
273
vllm/tool_parsers/abstract_tool_parser.py
Normal file
273
vllm/tool_parsers/abstract_tool_parser.py
Normal file
@@ -0,0 +1,273 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import cached_property
|
||||
|
||||
from openai.types.responses.response_format_text_json_schema_config import (
|
||||
ResponseFormatTextJSONSchemaConfig,
|
||||
)
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ExtractedToolCallInformation,
|
||||
ResponsesRequest,
|
||||
ResponseTextConfig,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import (
|
||||
StructuredOutputsParams,
|
||||
)
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.utils import get_json_schema_from_tools
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
from vllm.utils.import_utils import import_from_path
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ToolParser:
    """
    Base class for tool parsers; not intended to be instantiated directly.

    Subclasses implement ``extract_tool_calls`` (non-streaming) and
    ``extract_tool_calls_streaming`` (streaming) for a specific model's
    tool-call output format, and may reuse the streaming state attributes
    initialized here.
    """

    def __init__(self, tokenizer: TokenizerLike):
        # History of partially-parsed tool calls; streaming subclasses diff
        # against this between invocations.
        self.prev_tool_call_arr: list[dict] = []
        # the index of the tool call that is currently being parsed
        self.current_tool_id: int = -1
        # Whether the current tool call's function name was already streamed.
        self.current_tool_name_sent: bool = False
        # Per-tool accumulation of argument text already sent to the client.
        self.streamed_args_for_tool: list[str] = []

        self.model_tokenizer = tokenizer

    @cached_property
    def vocab(self) -> dict[str, int]:
        # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
        # whereas all tokenizers have .get_vocab()
        return self.model_tokenizer.get_vocab()

    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """
        Adjust the request so structured outputs enforce the JSON schema
        derived from the request's tools.

        No-op when the request has no tools or no schema can be derived.
        NOTE(review): despite the annotation, the body also handles
        ResponsesRequest instances (see the isinstance check below).
        """
        if not request.tools:
            return request
        json_schema_from_tool = get_json_schema_from_tools(
            tool_choice=request.tool_choice, tools=request.tools
        )
        # Set structured output params for tool calling
        if json_schema_from_tool is not None:
            if isinstance(request, ChatCompletionRequest):
                # Fresh params object: any previous structured-output settings
                # are intentionally replaced.
                request.structured_outputs = StructuredOutputsParams()
                # tool_choice: "Forced Function" or "required" will override
                # structured output json settings to make tool calling work correctly
                request.structured_outputs.json = json_schema_from_tool
            if isinstance(request, ResponsesRequest):
                request.text = ResponseTextConfig()
                request.text.format = ResponseFormatTextJSONSchemaConfig(
                    name="tool_calling_response",
                    schema=json_schema_from_tool,
                    type="json_schema",
                    description="Response format for tool calling",
                    strict=True,
                )

        return request

    def extract_tool_calls(
        self, model_output: str, request: ChatCompletionRequest
    ) -> ExtractedToolCallInformation:
        """
        Extract tool calls from a complete model-generated string.

        Used for non-streaming responses where the entire model response is
        available before sending to the client. Subclasses must override;
        the base implementation always raises NotImplementedError.
        """
        raise NotImplementedError(
            "AbstractToolParser.extract_tool_calls has not been implemented!"
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """
        Extract tool calls from an incomplete (streaming) response.

        An instance method because it requires state: the current
        tokens/diffs plus what has previously been parsed and extracted
        (see constructor). Returns None when nothing should be streamed for
        this delta. Subclasses must override; the base implementation always
        raises NotImplementedError.
        """
        raise NotImplementedError(
            "AbstractToolParser.extract_tool_calls_streaming has not been implemented!"
        )
|
||||
|
||||
|
||||
class ToolParserManager:
    """
    Central registry for ToolParser implementations.

    Supports two modes:
    - Eager (immediate) registration via `register_module`
    - Lazy registration via `register_lazy_module`

    Both registries are class-level dicts shared by all users of this class.
    """

    # Eagerly registered (or lazily loaded and since cached) parser classes.
    tool_parsers: dict[str, type[ToolParser]] = {}
    lazy_parsers: dict[str, tuple[str, str]] = {}  # name -> (module_path, class_name)

    @classmethod
    def get_tool_parser(cls, name: str) -> type[ToolParser]:
        """
        Retrieve a registered or lazily registered ToolParser class.

        If the parser is lazily registered, it will be imported and cached
        on first access (eager entries take precedence over lazy ones).
        Raises KeyError if not found.
        """
        if name in cls.tool_parsers:
            return cls.tool_parsers[name]

        if name in cls.lazy_parsers:
            return cls._load_lazy_parser(name)

        raise KeyError(f"Tool parser '{name}' not found.")

    @classmethod
    def _load_lazy_parser(cls, name: str) -> type[ToolParser]:
        """Import and register a lazily loaded parser.

        Raises whatever the import raises (after logging), or TypeError when
        the resolved attribute is not a ToolParser subclass.
        """
        module_path, class_name = cls.lazy_parsers[name]
        try:
            mod = importlib.import_module(module_path)
            parser_cls = getattr(mod, class_name)
            if not issubclass(parser_cls, ToolParser):
                raise TypeError(
                    f"{class_name} in {module_path} is not a ToolParser subclass."
                )
            cls.tool_parsers[name] = parser_cls  # cache
            return parser_cls
        except Exception as e:
            logger.exception(
                "Failed to import lazy tool parser '%s' from %s: %s",
                name,
                module_path,
                e,
            )
            raise

    @classmethod
    def _register_module(
        cls,
        module: type[ToolParser],
        module_name: str | list[str] | None = None,
        force: bool = True,
    ) -> None:
        """Register a ToolParser class immediately.

        When force is False, an already-registered name raises KeyError
        instead of being overwritten.
        """
        if not issubclass(module, ToolParser):
            # NOTE(review): `module` is a class here, so type(module) is
            # always <class 'type'>; the message presumably means to show
            # `module` itself — confirm intent before changing.
            raise TypeError(
                f"module must be subclass of ToolParser, but got {type(module)}"
            )

        if module_name is None:
            # Fall back to the class name as the registration key.
            module_name = module.__name__

        if isinstance(module_name, str):
            module_names = [module_name]
        elif is_list_of(module_name, str):
            module_names = module_name
        else:
            raise TypeError("module_name must be str, list[str], or None.")

        for name in module_names:
            if not force and name in cls.tool_parsers:
                existed = cls.tool_parsers[name]
                raise KeyError(f"{name} is already registered at {existed.__module__}")
            cls.tool_parsers[name] = module

    @classmethod
    def register_lazy_module(cls, name: str, module_path: str, class_name: str) -> None:
        """
        Register a lazy module mapping (no import happens here).

        Example:
            ToolParserManager.register_lazy_module(
                name="kimi_k2",
                module_path="vllm.tool_parsers.kimi_k2_parser",
                class_name="KimiK2ToolParser",
            )
        """
        cls.lazy_parsers[name] = (module_path, class_name)

    @classmethod
    def register_module(
        cls,
        name: str | list[str] | None = None,
        force: bool = True,
        module: type[ToolParser] | None = None,
    ) -> type[ToolParser] | Callable[[type[ToolParser]], type[ToolParser]]:
        """
        Register a module immediately (module given) or lazily (as a
        decorator).

        Usage:
            @ToolParserManager.register_module("kimi_k2")
            class KimiK2ToolParser(ToolParser):
                ...

        Or:
            ToolParserManager.register_module(module=SomeToolParser)
        """
        if not isinstance(force, bool):
            raise TypeError(f"force must be a boolean, but got {type(force)}")

        # Immediate registration
        if module is not None:
            cls._register_module(module=module, module_name=name, force=force)
            return module

        # Decorator usage: the decorated class is recorded lazily under
        # `name` (or its own class name) and returned unchanged.
        def _decorator(obj: type[ToolParser]) -> type[ToolParser]:
            module_path = obj.__module__
            class_name = obj.__name__

            if isinstance(name, str):
                names = [name]
            elif name is not None and is_list_of(name, str):
                names = name
            else:
                names = [class_name]

            for n in names:
                # Lazy mapping only: do not import now
                cls.lazy_parsers[n] = (module_path, class_name)

            return obj

        return _decorator

    @classmethod
    def list_registered(cls) -> list[str]:
        """Return names of all eagerly and lazily registered tool parsers."""
        return sorted(set(cls.tool_parsers.keys()) | set(cls.lazy_parsers.keys()))

    @classmethod
    def import_tool_parser(cls, plugin_path: str) -> None:
        """Import a user-defined parser file from an arbitrary path.

        Best-effort: import failures are logged and swallowed, not raised.
        """

        module_name = os.path.splitext(os.path.basename(plugin_path))[0]
        try:
            import_from_path(module_name, plugin_path)
        except Exception:
            logger.exception(
                "Failed to load module '%s' from %s.", module_name, plugin_path
            )
|
||||
388
vllm/tool_parsers/deepseekv31_tool_parser.py
Normal file
388
vllm/tool_parsers/deepseekv31_tool_parser.py
Normal file
@@ -0,0 +1,388 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import ToolParser
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DeepSeekV31ToolParser(ToolParser):
    """Tool parser for DeepSeek-V3.1 style tool-call markup.

    Expects output of the form
    ``<|tool▁calls▁begin|><|tool▁call▁begin|>NAME<|tool▁sep|>ARGS<|tool▁call▁end|>...<|tool▁calls▁end|>``
    and supports both complete (non-streaming) and incremental (streaming)
    extraction.
    """

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)

        # Streaming state (mirrors the base class attributes).
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[
            str
        ] = []  # map what has been streamed for each tool so far to a list

        self.tool_calls_start_token: str = "<|tool▁calls▁begin|>"
        self.tool_calls_end_token: str = "<|tool▁calls▁end|>"

        self.tool_call_start_token: str = "<|tool▁call▁begin|>"
        self.tool_call_end_token: str = "<|tool▁call▁end|>"

        # NOTE(review): "|" is a regex alternation metacharacter. If these
        # patterns really contain the ASCII bar (rather than the fullwidth
        # bar U+FF5C used by DeepSeek tokenizers, which this dump may have
        # mangled), they do not match the sentinel tokens literally —
        # confirm against the tokenizer vocabulary.
        self.tool_call_regex = re.compile(
            r"<|tool▁call▁begin|>(?P<function_name>.*?)<|tool▁sep|>(?P<function_arguments>.*?)<|tool▁call▁end|>"
        )

        self.stream_tool_call_portion_regex = re.compile(
            r"(?P<function_name>.*)<|tool▁sep|>(?P<function_arguments>.*)"
        )

        self.stream_tool_call_name_regex = re.compile(
            r"(?P<function_name>.*)<|tool▁sep|>"
        )

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )
        self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token)
        self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token)

        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)

        if (
            self.tool_calls_start_token_id is None
            or self.tool_calls_end_token_id is None
        ):
            raise RuntimeError(
                "DeepSeek-V3.1 Tool parser could not locate tool call "
                "start/end tokens in the tokenizer!"
            )

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract all tool calls from a complete model output string."""
        # sanity check; avoid unnecessary processing
        if self.tool_calls_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        else:
            try:
                # there are two possible captures - between tags, or between a
                # tag and end-of-string so the result of
                # findall is an array of tuples where one is a function call and
                # the other is None
                function_call_tuples = self.tool_call_regex.findall(model_output)

                tool_calls = []
                for match in function_call_tuples:
                    function_name, function_args = match
                    tool_calls.append(
                        ToolCall(
                            type="function",
                            function=FunctionCall(
                                name=function_name, arguments=function_args
                            ),
                        )
                    )

                # Everything before the tool-calls block is plain content.
                content = model_output[: model_output.find(self.tool_calls_start_token)]
                return ExtractedToolCallInformation(
                    tools_called=True,
                    tool_calls=tool_calls,
                    content=content if content else None,
                )

            except Exception:
                logger.exception("Error in extracting tool call from response.")
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=model_output
                )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Incrementally extract tool-call deltas while streaming.

        Returns a DeltaMessage with either plain content or tool-call
        fragments, or None when nothing should be emitted for this delta.
        """
        logger.debug("delta_text: %s", delta_text)
        logger.debug("delta_token_ids: %s", delta_token_ids)
        # check to see if we should be streaming a tool call - is there a
        if self.tool_calls_start_token_id not in current_token_ids:
            logger.debug("No tool call tokens found!")
            return DeltaMessage(content=delta_text)
        # Strip the outer begin/end sentinels; they never reach the client.
        delta_text = delta_text.replace(self.tool_calls_start_token, "").replace(
            self.tool_calls_end_token, ""
        )
        try:
            # figure out where we are in the parsing by counting tool call
            # start & end tags
            prev_tool_start_count = previous_token_ids.count(
                self.tool_call_start_token_id
            )
            prev_tool_end_count = previous_token_ids.count(self.tool_call_end_token_id)
            cur_tool_start_count = current_token_ids.count(
                self.tool_call_start_token_id
            )
            cur_tool_end_count = current_token_ids.count(self.tool_call_end_token_id)
            tool_call_portion = None
            text_portion = None

            # case: if we're generating text, OR rounding out a tool call
            if (
                cur_tool_start_count == cur_tool_end_count
                and prev_tool_end_count == cur_tool_end_count
                and self.tool_call_end_token not in delta_text
            ):
                logger.debug("Generating text content! skipping tool parsing.")
                return DeltaMessage(content=delta_text)

            if self.tool_call_end_token in delta_text:
                logger.debug("tool_call_end_token in delta_text")
                full_text = current_text + delta_text
                tool_call_portion = (
                    full_text.split(self.tool_call_start_token)[-1]
                    .split(self.tool_call_end_token)[0]
                    .rstrip()
                )
                delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip()
                # NOTE(review): after the previous line delta_text no longer
                # contains the end token, so this split is a no-op and
                # text_portion == delta_text.lstrip() — confirm the intended
                # order of these two statements.
                text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip()

            # case -- we're starting a new tool call
            if (
                cur_tool_start_count > cur_tool_end_count
                and cur_tool_start_count > prev_tool_start_count
            ):
                if len(delta_token_ids) > 1:
                    tool_call_portion = current_text.split(self.tool_call_start_token)[
                        -1
                    ]
                else:
                    tool_call_portion = None
                    delta = None

                text_portion = None

                # set cursors and state appropriately
                self.current_tool_id += 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("Starting on a new tool %s", self.current_tool_id)

            # case -- we're updating an existing tool call
            elif (
                cur_tool_start_count > cur_tool_end_count
                and cur_tool_start_count == prev_tool_start_count
            ):
                # get the portion of the text that's the tool call
                tool_call_portion = current_text.split(self.tool_call_start_token)[-1]
                text_portion = None

            # case -- the current tool call is being closed.
            elif (
                cur_tool_start_count == cur_tool_end_count
                and cur_tool_end_count >= prev_tool_end_count
            ):
                if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0:
                    logger.debug("attempting to close tool call, but no tool call")
                    return None
                diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
                if diff:
                    # NOTE(review): the conditional below is effectively dead:
                    # `diff is str` is an identity test against the type
                    # object and is always False, and `diff` is reassigned
                    # from delta_text before use in any case.
                    diff = (
                        diff.encode("utf-8").decode("unicode_escape")
                        if diff is str
                        else diff
                    )
                    if '"}' not in delta_text:
                        return None
                    end_loc = delta_text.rindex('"}')
                    diff = delta_text[:end_loc] + '"}'
                    logger.debug(
                        "Finishing tool and found diff that had not "
                        "been streamed yet: %s",
                        diff,
                    )
                    self.streamed_args_for_tool[self.current_tool_id] += diff
                    return DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                function=DeltaFunctionCall(arguments=diff).model_dump(
                                    exclude_none=True
                                ),
                            )
                        ]
                    )

            # case -- otherwise we're just generating text
            else:
                text = delta_text.replace(self.tool_call_start_token, "")
                text = text.replace(self.tool_call_end_token, "")
                delta = DeltaMessage(tool_calls=[], content=text)
                return delta

            current_tool_call = dict()
            if tool_call_portion:
                current_tool_call_matches = self.stream_tool_call_portion_regex.match(
                    tool_call_portion
                )
                if current_tool_call_matches:
                    tool_name, tool_args = current_tool_call_matches.groups()
                    current_tool_call["name"] = tool_name
                    current_tool_call["arguments"] = tool_args
                else:
                    current_tool_call_name_matches = (
                        self.stream_tool_call_name_regex.match(tool_call_portion)
                    )
                    if current_tool_call_name_matches:
                        # FIX: .groups() returns a tuple; extract the single
                        # captured group so the function *name* streamed to
                        # the client is a plain string, not a 1-tuple.
                        tool_name = current_tool_call_name_matches.group(
                            "function_name"
                        )
                        current_tool_call["name"] = tool_name
                        current_tool_call["arguments"] = ""
                    else:
                        logger.debug("Not enough token")
                        return None

            # case - we haven't sent the tool name yet. If it's available, send
            # it. otherwise, wait until it's available.
            if not self.current_tool_name_sent:
                if current_tool_call is None:
                    return None
                function_name: str | None = current_tool_call.get("name")
                if function_name:
                    self.current_tool_name_sent = True
                    return DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                type="function",
                                id=make_tool_call_id(),
                                function=DeltaFunctionCall(
                                    name=function_name
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                else:
                    return None

            # case -- otherwise, send the tool call delta

            # if the tool call portion is None, send the delta as text
            if tool_call_portion is None:
                # if there's text but not tool calls, send that -
                # otherwise None to skip chunk
                delta = (
                    DeltaMessage(content=delta_text)
                    if text_portion is not None
                    else None
                )
                return delta

            # now, the nitty-gritty of tool calls
            # now we have the portion to parse as tool call.

            logger.debug(
                "Trying to parse current tool call with ID %s", self.current_tool_id
            )

            # if we're starting a new tool call, push an empty object in as
            # a placeholder for the arguments
            if len(self.prev_tool_call_arr) <= self.current_tool_id:
                self.prev_tool_call_arr.append({})

            # main logic for tool parsing here - compare prev. partially-parsed
            # JSON to the current partially-parsed JSON
            prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                "arguments"
            )
            cur_arguments = current_tool_call.get("arguments")

            logger.debug("diffing old arguments: %s", prev_arguments)
            logger.debug("against new ones: %s", cur_arguments)

            # case -- no arguments have been created yet. skip sending a delta.
            if not cur_arguments and not prev_arguments:
                logger.debug("Skipping text %s - no arguments", delta_text)
                delta = None

            # case -- prev arguments are defined, but non are now.
            # probably impossible, but not a fatal error - just keep going
            elif not cur_arguments and prev_arguments:
                logger.error(
                    "should be impossible to have arguments reset "
                    "mid-call. skipping streaming anything."
                )
                delta = None

            # case -- we now have the first info about arguments available from
            # autocompleting the JSON
            elif cur_arguments and not prev_arguments:
                delta = DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            function=DeltaFunctionCall(
                                arguments=cur_arguments
                            ).model_dump(exclude_none=True),
                        )
                    ]
                )
                self.streamed_args_for_tool[self.current_tool_id] = cur_arguments

            # last case -- we have an update to existing arguments.
            elif cur_arguments and prev_arguments:
                if (
                    isinstance(delta_text, str)
                    and cur_arguments != prev_arguments
                    and len(cur_arguments) > len(prev_arguments)
                    and cur_arguments.startswith(prev_arguments)
                ):
                    delta_arguments = cur_arguments[len(prev_arguments) :]
                    logger.debug("got diff %s", delta_text)

                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                function=DeltaFunctionCall(
                                    arguments=delta_arguments
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.streamed_args_for_tool[self.current_tool_id] = cur_arguments
                else:
                    delta = None

            # handle saving the state for the current tool into
            # the "prev" list for use in diffing for the next iteration
            if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
                self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
            else:
                self.prev_tool_call_arr.append(current_tool_call)

            return delta

        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            return None  # do not stream a delta. skip this token ID.
|
||||
591
vllm/tool_parsers/deepseekv32_tool_parser.py
Normal file
591
vllm/tool_parsers/deepseekv32_tool_parser.py
Normal file
@@ -0,0 +1,591 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DeepSeekV32ToolParser(ToolParser):
|
||||
"""
|
||||
example tool call content:
|
||||
<|DSML|function_calls>
|
||||
<|DSML|invoke name="get_weather">
|
||||
<|DSML|parameter name="location" string="true">杭州</|DSML|parameter>
|
||||
<|DSML|parameter name="date" string="true">2024-01-16</|DSML|parameter>
|
||||
</|DSML|invoke>
|
||||
<|DSML|invoke name="get_weather">
|
||||
<|DSML|parameter name="location" string="true">北京</|DSML|parameter>
|
||||
<|DSML|parameter name="date" string="true">2024-01-16</|DSML|parameter>
|
||||
</|DSML|invoke>
|
||||
</|DSML|function_calls>
|
||||
"""
|
||||
|
||||
    def __init__(self, tokenizer: TokenizerLike):
        """Initialize DSML sentinel strings, streaming state and regexes."""
        super().__init__(tokenizer)

        self.prev_tool_call_arr: list[dict] = []

        # Sentinel tokens
        self.dsml_token: str = "|DSML|"
        self.dsml_start_check: str = "<" + self.dsml_token
        self.tool_call_start_token: str = "<|DSML|function_calls>"
        self.tool_call_end_token: str = "</|DSML|function_calls>"
        self.invoke_start_prefix: str = "<|DSML|invoke name="
        self.invoke_end_token: str = "</|DSML|invoke>"
        self.parameter_prefix: str = "<|DSML|parameter name="
        self.parameter_end_token: str = "</|DSML|parameter>"

        # Streaming state variables
        self.current_tool_name_sent: bool = False
        # Override base class type - we use string IDs for tool calls
        self.current_tool_id: str | None = None  # type: ignore
        self.streamed_args_for_tool: list[str] = []
        self.is_tool_call_started: bool = False
        self.failed_count: int = 0

        # Initialize streaming state variables
        self.current_tool_index: int = 0
        self.invoke_index: int = 0
        self.header_sent: bool = False
        self.current_function_name: str | None = None
        self.current_param_name: str | None = None
        self.current_param_value: str = ""
        self.param_count: int = 0
        self.in_param: bool = False
        self.in_function: bool = False
        self.json_started: bool = False
        self.json_closed: bool = False
        self.accumulated_params: dict = {}
        self.streaming_request: ChatCompletionRequest | None = None

        # Enhanced streaming state - reset for each new message
        self._reset_streaming_state()

        # Regex patterns for complete parsing.
        # NOTE(review): "|" inside these patterns is a regex alternation
        # metacharacter, so "<|DSML|..." does not match the literal sentinel
        # "<|DSML|..." — confirm whether the patterns should escape the bars
        # (or whether the real source uses different characters that this
        # dump mangled).
        self.tool_call_complete_regex = re.compile(
            r"<|DSML|function_calls>(.*?)</|DSML|function_calls>", re.DOTALL
        )
        self.invoke_complete_regex = re.compile(
            r'<|DSML|invoke\s+name="([^"]+)"\s*>(.*?)</|DSML|invoke>', re.DOTALL
        )
        self.parameter_complete_regex = re.compile(
            r'<|DSML|parameter\s+name="([^"]+)"\s+string="(?:true|false)"\s*>(.*?)</|DSML|parameter>',
            re.DOTALL,
        )

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        logger.debug(
            "vLLM Successfully import tool parser %s !", self.__class__.__name__
        )
|
||||
|
||||
def _generate_tool_call_id(self) -> str:
|
||||
"""Generate a unique tool call ID."""
|
||||
return f"call_{uuid.uuid4().hex[:24]}"
|
||||
|
||||
def _reset_streaming_state(self):
|
||||
"""Reset all streaming state."""
|
||||
self.current_tool_index = 0
|
||||
self.invoke_index = 0
|
||||
self.is_tool_call_started = False
|
||||
self.header_sent = False
|
||||
self.current_tool_id = None
|
||||
self.current_function_name = None
|
||||
self.current_param_name = None
|
||||
self.current_param_value = ""
|
||||
self.param_count = 0
|
||||
self.in_param = False
|
||||
self.in_function = False
|
||||
self.json_started = False
|
||||
self.json_closed = False
|
||||
# Store accumulated parameters for type conversion
|
||||
self.accumulated_params = {}
|
||||
self.streaming_request = None
|
||||
# Clear previous tool call history to avoid state pollution
|
||||
self.prev_tool_call_arr.clear()
|
||||
|
||||
def _parse_invoke_params(self, invoke_str: str) -> dict | None:
|
||||
param_dict = dict()
|
||||
for param_name, param_val in self.parameter_complete_regex.findall(invoke_str):
|
||||
param_dict[param_name] = param_val
|
||||
return param_dict
|
||||
|
||||
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """Extract tool calls from complete model output (non-streaming).

    Returns a no-tools result when the output contains no tool-call start
    token, when no complete invocation parses, or when parsing raises;
    otherwise returns the parsed calls plus any leading plain text.
    """
    # Fast path: nothing that even looks like a tool call.
    if self.tool_call_start_token not in model_output:
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    try:
        # One ToolCall per <invoke> inside each complete tool_call block.
        tool_calls = [
            ToolCall(
                type="function",
                function=FunctionCall(
                    name=invoke_name,
                    arguments=json.dumps(
                        self._parse_invoke_params(invoke_body),
                        ensure_ascii=False,
                    ),
                ),
            )
            for block in self.tool_call_complete_regex.findall(model_output)
            for invoke_name, invoke_body in self.invoke_complete_regex.findall(
                block
            )
        ]

        if not tool_calls:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        # Plain text emitted before the first tool call, if any.
        first_tool_idx = model_output.find(self.tool_call_start_token)
        content = model_output[:first_tool_idx] if first_tool_idx > 0 else None

        return ExtractedToolCallInformation(
            tools_called=True, tool_calls=tool_calls, content=content
        )

    except Exception:
        logger.exception("Error extracting tool calls")
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )
def _extract_name(self, name_str: str) -> str:
|
||||
"""Extract name from quoted string."""
|
||||
name_str = name_str.strip()
|
||||
if (
|
||||
name_str.startswith('"')
|
||||
and name_str.endswith('"')
|
||||
or name_str.startswith("'")
|
||||
and name_str.endswith("'")
|
||||
):
|
||||
return name_str[1:-1]
|
||||
return name_str
|
||||
|
||||
def _extract_param_name(self, input_str: str) -> str:
|
||||
"""Extract param name"""
|
||||
start = input_str.find('"') + 1
|
||||
end = input_str.find('"', start)
|
||||
return input_str[start:end] if start > 0 and end > start else input_str
|
||||
|
||||
def _convert_param_value(self, value: str, param_type: str) -> Any:
|
||||
"""Convert parameter value to the correct type."""
|
||||
if value.lower() == "null":
|
||||
return None
|
||||
|
||||
param_type = param_type.lower()
|
||||
if param_type in ["string", "str", "text"]:
|
||||
return value
|
||||
elif param_type in ["integer", "int"]:
|
||||
try:
|
||||
return int(value)
|
||||
except (ValueError, TypeError):
|
||||
return value
|
||||
elif param_type in ["number", "float"]:
|
||||
try:
|
||||
val = float(value)
|
||||
return val if val != int(val) else int(val)
|
||||
except (ValueError, TypeError):
|
||||
return value
|
||||
elif param_type in ["boolean", "bool"]:
|
||||
return value.lower() in ["true", "1"]
|
||||
elif param_type in ["object", "array"]:
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
else:
|
||||
# Try JSON parse first, fallback to string
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
|
||||
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],  # pylint: disable=unused-argument
    current_token_ids: Sequence[int],  # pylint: disable=unused-argument
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None:
    """Extract tool calls from streaming model output.

    Incrementally parses DSML-style markup (<invoke ...> blocks containing
    <parameter ...> tags) as text arrives, emitting at most one
    DeltaMessage per call: first the function header (name + generated
    id), then an opening "{", then one JSON fragment per completed
    parameter, and finally the closing "}". Plain text outside tool
    calls is passed through as content deltas. Returns None when there
    is nothing to emit for this delta.
    """

    # Store request for type conversion
    # A fresh message (no previous text) resets all per-message state.
    if not previous_text:
        self._reset_streaming_state()
        self.streaming_request = request

    # If no delta text, return None unless it's an EOS token after tools
    if not delta_text:
        # Check if this is an EOS token after all tool calls are complete
        if delta_token_ids:
            # Count complete tool calls
            complete_calls = len(
                self.tool_call_complete_regex.findall(current_text)
            )

            # If we have completed tool calls and populated prev_tool_call_arr
            if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
                # Check if all tool calls are closed
                open_calls = current_text.count(
                    self.tool_call_start_token
                ) - current_text.count(self.tool_call_end_token)
                if open_calls == 0:
                    # Return empty delta for finish_reason processing
                    return DeltaMessage(content="")
            elif not self.is_tool_call_started and current_text:
                # This is a regular content response that's now complete
                return DeltaMessage(content="")
        return None

    # Check if we need to advance to next tool
    if self.json_closed and not self.in_function:
        # Check if this tool call has ended
        invoke_ends = current_text.count(self.invoke_end_token)
        if invoke_ends > self.current_tool_index:
            # This tool has ended, advance to next
            self.current_tool_index += 1
            self.header_sent = False
            self.param_count = 0
            self.json_started = False
            self.json_closed = False
            self.in_function = False  # Now we can safely set this to False
            self.accumulated_params = {}
            # Continue processing next tool
            return None

    # Handle normal content before tool calls
    if not self.is_tool_call_started:
        # Check if tool call is starting
        if self.dsml_token in current_text:
            self.is_tool_call_started = True
            # Return any content before the tool call
            if self.dsml_start_check in delta_text:
                content_before = delta_text[
                    : delta_text.index(self.dsml_start_check)
                ]
                if content_before:
                    return DeltaMessage(content=content_before)
            return None
        else:
            # Check if we're between tool calls - skip whitespace
            if (
                current_text.rstrip().endswith(self.tool_call_end_token)
                and delta_text.strip() == ""
            ):
                # We just ended a tool call, skip whitespace
                return None
            # Normal content, no tool call.
            # Hold back a trailing "<" — it may be the start of a DSML tag;
            # it is re-emitted on the next delta if it turned out to be text.
            if delta_text.endswith("<"):
                return DeltaMessage(content=delta_text[:-1])
            if previous_text and previous_text.endswith("<"):
                return DeltaMessage(content="<" + delta_text)
            return DeltaMessage(content=delta_text)

    # Check if we're between tool calls (waiting for next one)
    invoke_starts_count = current_text.count(self.invoke_start_prefix)
    if self.current_tool_index >= invoke_starts_count:
        # We're past all tool calls, shouldn't be here
        return None

    # Find the current tool call portion
    invoke_start_positions: list[int] = []
    idx = 0
    while True:
        idx = current_text.find(self.invoke_start_prefix, idx)
        if idx == -1:
            break
        invoke_start_positions.append(idx)
        idx += len(self.invoke_start_prefix)

    if self.current_tool_index >= len(invoke_start_positions):
        # No more tool calls to process yet
        return None

    invoke_start_idx = invoke_start_positions[self.current_tool_index]
    # Find where this tool call ends (or current position if not ended yet)
    invoke_end_idx = current_text.find(self.invoke_end_token, invoke_start_idx)
    if invoke_end_idx == -1:
        tool_text = current_text[invoke_start_idx:]
    else:
        tool_text = current_text[
            invoke_start_idx : invoke_end_idx + len(self.invoke_end_token)
        ]

    # Looking for function header
    if not self.header_sent:
        if self.invoke_start_prefix in tool_text:
            func_start = tool_text.find(self.invoke_start_prefix) + len(
                self.invoke_start_prefix
            )
            # Find the end quote for the function name
            func_end = tool_text.find(">", func_start)

            if func_end != -1:
                # Found complete function name
                function_name_raw = tool_text[func_start:func_end]
                self.current_function_name = self._extract_name(function_name_raw)
                self.current_tool_id = self._generate_tool_call_id()
                self.header_sent = True
                self.in_function = True

                # Add to prev_tool_call_arr immediately when we detect a tool call
                # Each tool call should be recorded regardless of function name
                # Ensure we don't add the same tool call index multiple times
                if len(self.prev_tool_call_arr) <= self.current_tool_index:
                    self.prev_tool_call_arr.append(
                        {
                            "name": self.current_function_name,
                            "arguments": "{}",  # Placeholder, will be updated later
                        }
                    )

                # Send header with function info
                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_index,
                            id=self.current_tool_id,
                            function=DeltaFunctionCall(
                                name=self.current_function_name, arguments=""
                            ),
                            type="function",
                        )
                    ]
                )
        return None

    # We've sent header, now handle function body
    if self.in_function:
        # Send opening brace if not sent yet
        if self.in_function and not self.json_started:
            self.json_started = True
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_index,
                        function=DeltaFunctionCall(arguments="{"),
                    )
                ]
            )

        # Make sure json_started is set if we're processing parameters
        if not self.json_started:
            self.json_started = True

        # Check for function end in accumulated text
        if not self.json_closed and self.invoke_end_token in tool_text:
            # Count total parameters in the tool text
            total_param_count = tool_text.count(self.parameter_prefix)

            # Only close JSON if all parameters have been processed
            if self.param_count >= total_param_count:
                # Close JSON
                self.json_closed = True

                # Extract complete tool call
                # Find the invoke content
                invoke_start = tool_text.find(self.invoke_start_prefix) + len(
                    self.invoke_start_prefix
                )
                invoke_content_end = tool_text.find(
                    self.invoke_end_token, invoke_start
                )
                if invoke_content_end != -1:
                    invoke_content = tool_text[invoke_start:invoke_content_end]
                    # Parse to get the complete arguments
                    try:
                        invoke_params = self._parse_invoke_params(invoke_content)
                        if invoke_params and self.current_tool_index < len(
                            self.prev_tool_call_arr
                        ):
                            # Update existing entry in prev_tool_call_arr
                            self.prev_tool_call_arr[self.current_tool_index][
                                "arguments"
                            ] = json.dumps(invoke_params, ensure_ascii=False)
                    except Exception:
                        pass  # Ignore parsing errors during streaming

                result = DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_index,
                            function=DeltaFunctionCall(arguments="}"),
                        )
                    ]
                )

                # Reset state for next tool
                self.json_closed = True  # (already set above; kept for clarity)
                self.in_function = False
                self.accumulated_params = {}

                logger.debug("[M2_STREAMING] Tool call completed")

                return result
            else:
                # Don't close JSON yet, continue processing parameters
                return None

        # Look for parameters
        # Find all parameter starts
        param_starts = []
        idx = 0
        while True:
            idx = tool_text.find(self.parameter_prefix, idx)
            if idx == -1:
                break
            param_starts.append(idx)
            idx += len(self.parameter_prefix)

        # Check if we should start a new parameter
        if (
            not self.in_param
            and self.param_count < len(param_starts)
            # NOTE: this clause is logically redundant with the one above.
            and len(param_starts) > self.param_count
        ):
            # Process the next parameter
            param_idx = param_starts[self.param_count]
            param_start = param_idx + len(self.parameter_prefix)
            remaining = tool_text[param_start:]

            if ">" in remaining:
                # We have the complete parameter name
                name_end = remaining.find(">")
                param_name_raw = remaining[:name_end]
                self.current_param_name = self._extract_param_name(param_name_raw)

                # Find the parameter value
                value_start = param_start + name_end + 1
                value_text = tool_text[value_start:]
                if value_text.startswith("\n"):
                    value_text = value_text[1:]

                # Find where this parameter ends
                param_end_idx = value_text.find(self.parameter_end_token)
                if param_end_idx == -1:
                    # No closing tag, look for next parameter or function end
                    next_param_idx = value_text.find(self.parameter_prefix)
                    func_end_idx = value_text.find(self.invoke_end_token)

                    if next_param_idx != -1 and (
                        func_end_idx == -1 or next_param_idx < func_end_idx
                    ):
                        param_end_idx = next_param_idx
                    elif func_end_idx != -1:
                        param_end_idx = func_end_idx
                    else:
                        # Neither found, check if tool call is complete
                        if self.invoke_end_token in tool_text:
                            # Tool call and parameter is complete
                            param_end_idx = len(value_text)
                        else:
                            # Still streaming, wait for more content
                            return None

                if param_end_idx != -1:
                    # Complete parameter found
                    param_value = value_text[:param_end_idx]
                    if param_value.endswith("\n"):
                        param_value = param_value[:-1]

                    # Store raw value for later processing
                    self.accumulated_params[self.current_param_name] = param_value

                    # Get parameter configuration for type conversion
                    param_config = {}
                    if self.streaming_request and self.streaming_request.tools:
                        for tool in self.streaming_request.tools:
                            if (
                                hasattr(tool, "function")
                                and tool.function.name == self.current_function_name
                                and hasattr(tool.function, "parameters")
                            ):
                                params = tool.function.parameters
                                if (
                                    isinstance(params, dict)
                                    and "properties" in params
                                ):
                                    param_config = params["properties"]
                                break

                    # Get parameter type
                    param_type = "string"
                    if (
                        self.current_param_name in param_config
                        and isinstance(param_config[self.current_param_name], dict)
                        and "type" in param_config[self.current_param_name]
                    ):
                        param_type = param_config[self.current_param_name]["type"]

                    # Convert param value to appropriate type
                    converted_value = self._convert_param_value(
                        param_value, param_type
                    )

                    # Build JSON fragment based on the converted type
                    # Use json.dumps to properly serialize the value
                    serialized_value = json.dumps(
                        converted_value, ensure_ascii=False
                    )

                    if self.param_count == 0:
                        json_fragment = (
                            f'"{self.current_param_name}": {serialized_value}'
                        )
                    else:
                        json_fragment = (
                            f', "{self.current_param_name}": {serialized_value}'
                        )

                    self.param_count += 1

                    return DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_index,
                                function=DeltaFunctionCall(arguments=json_fragment),
                            )
                        ]
                    )

    return None
390
vllm/tool_parsers/deepseekv3_tool_parser.py
Normal file
390
vllm/tool_parsers/deepseekv3_tool_parser.py
Normal file
@@ -0,0 +1,390 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DeepSeekV3ToolParser(ToolParser):
    """Tool parser for DeepSeek-V3 tool-call markup.

    Expects calls of the form::

        <|tool▁calls▁begin|><|tool▁call▁begin|>TYPE<|tool▁sep|>NAME
        ```json
        {...arguments...}
        ```<|tool▁call▁end|>...<|tool▁calls▁end|>
    """

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)

        # Per-request streaming state.
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[
            str
        ] = []  # map what has been streamed for each tool so far to a list

        self.tool_calls_start_token: str = "<|tool▁calls▁begin|>"
        self.tool_calls_end_token: str = "<|tool▁calls▁end|>"

        self.tool_call_start_token: str = "<|tool▁call▁begin|>"
        self.tool_call_end_token: str = "<|tool▁call▁end|>"

        # NOTE(review): '|' is unescaped inside these patterns, so re treats
        # e.g. "<|tool▁call▁begin|>" as an alternation ("<" OR
        # "tool▁call▁begin" OR ">") rather than the literal token — the match
        # is looser than the pattern reads; confirm intent before changing.
        self.tool_call_regex = re.compile(
            r"<|tool▁call▁begin|>(?P<type>.*)<|tool▁sep|>(?P<function_name>.*)\n```json\n(?P<function_arguments>.*)\n```<|tool▁call▁end|>"
        )

        self.stream_tool_call_portion_regex = re.compile(
            r"(?P<type>.*)<|tool▁sep|>(?P<function_name>.*)\n```json\n(?P<function_arguments>.*[^\n`])"
        )

        self.stream_tool_call_name_regex = re.compile(
            r"(?P<type>.*)<|tool▁sep|>(?P<function_name>.*)\n"
        )

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )
        # Token ids for the marker tokens; .get() yields None when absent.
        self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token)
        self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token)

        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)

        if (
            self.tool_calls_start_token_id is None
            or self.tool_calls_end_token_id is None
        ):
            raise RuntimeError(
                "DeepSeek-V3 Tool parser could not locate tool call start/end "
                "tokens in the tokenizer!"
            )

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse all complete tool calls from a finished generation."""
        # sanity check; avoid unnecessary processing
        if self.tool_calls_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        else:
            try:
                # there are two possible captures - between tags, or between a
                # tag and end-of-string so the result of
                # findall is an array of tuples where one is a function call and
                # the other is None
                function_call_tuples = self.tool_call_regex.findall(model_output)

                tool_calls = []
                for match in function_call_tuples:
                    tool_type, function_name, function_args = match
                    tool_calls.append(
                        ToolCall(
                            type=tool_type,
                            function=FunctionCall(
                                name=function_name, arguments=function_args
                            ),
                        )
                    )

                # Plain text before the first tool call, if any.
                content = model_output[: model_output.find(self.tool_calls_start_token)]
                return ExtractedToolCallInformation(
                    tools_called=True,
                    tool_calls=tool_calls,
                    content=content if content else None,
                )

            except Exception:
                logger.exception("Error in extracting tool call from response.")
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=model_output
                )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Stream tool-call deltas by diffing partially-parsed arguments.

        State transitions are driven by counting tool-call start/end token
        ids in previous vs. current token id sequences.
        """
        logger.debug("delta_text: %s", delta_text)
        logger.debug("delta_token_ids: %s", delta_token_ids)
        # check to see if we should be streaming a tool call - is there a
        if self.tool_calls_start_token_id not in current_token_ids:
            logger.debug("No tool call tokens found!")
            return DeltaMessage(content=delta_text)
        delta_text = delta_text.replace(self.tool_calls_start_token, "").replace(
            self.tool_calls_end_token, ""
        )
        try:
            # figure out where we are in the parsing by counting tool call
            # start & end tags
            prev_tool_start_count = previous_token_ids.count(
                self.tool_call_start_token_id
            )
            prev_tool_end_count = previous_token_ids.count(self.tool_call_end_token_id)
            cur_tool_start_count = current_token_ids.count(
                self.tool_call_start_token_id
            )
            cur_tool_end_count = current_token_ids.count(self.tool_call_end_token_id)
            tool_call_portion = None
            text_portion = None

            # case: if we're generating text, OR rounding out a tool call
            if (
                cur_tool_start_count == cur_tool_end_count
                and prev_tool_end_count == cur_tool_end_count
                and self.tool_call_end_token not in delta_text
            ):
                logger.debug("Generating text content! skipping tool parsing.")
                return DeltaMessage(content=delta_text)

            if self.tool_call_end_token in delta_text:
                logger.debug("tool_call_end_token in delta_text")
                full_text = current_text + delta_text
                tool_call_portion = (
                    full_text.split(self.tool_call_start_token)[-1]
                    .split(self.tool_call_end_token)[0]
                    .rstrip()
                )
                delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip()
                text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip()

            # case -- we're starting a new tool call
            if (
                cur_tool_start_count > cur_tool_end_count
                and cur_tool_start_count > prev_tool_start_count
            ):
                if len(delta_token_ids) > 1:
                    tool_call_portion = current_text.split(self.tool_call_start_token)[
                        -1
                    ]
                else:
                    tool_call_portion = None
                    delta = None

                text_portion = None

                # set cursors and state appropriately
                self.current_tool_id += 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("Starting on a new tool %s", self.current_tool_id)

            # case -- we're updating an existing tool call
            elif (
                cur_tool_start_count > cur_tool_end_count
                and cur_tool_start_count == prev_tool_start_count
            ):
                # get the portion of the text that's the tool call
                tool_call_portion = current_text.split(self.tool_call_start_token)[-1]
                text_portion = None

            # case -- the current tool call is being closed.
            elif (
                cur_tool_start_count == cur_tool_end_count
                and cur_tool_end_count >= prev_tool_end_count
            ):
                if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0:
                    logger.debug("attempting to close tool call, but no tool call")
                    return None
                diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
                if diff:
                    # NOTE(review): `diff is str` is an identity test against
                    # the *type* and is always False, so the unicode_escape
                    # branch never runs; likely meant isinstance(diff, str).
                    diff = (
                        diff.encode("utf-8").decode("unicode_escape")
                        if diff is str
                        else diff
                    )
                    if '"}' not in delta_text:
                        return None
                    end_loc = delta_text.rindex('"}')
                    diff = delta_text[:end_loc] + '"}'
                    logger.debug(
                        "Finishing tool and found diff that had not "
                        "been streamed yet: %s",
                        diff,
                    )
                    self.streamed_args_for_tool[self.current_tool_id] += diff
                    return DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                function=DeltaFunctionCall(arguments=diff).model_dump(
                                    exclude_none=True
                                ),
                            )
                        ]
                    )

            # case -- otherwise we're just generating text
            else:
                text = delta_text.replace(self.tool_call_start_token, "")
                text = text.replace(self.tool_call_end_token, "")
                delta = DeltaMessage(tool_calls=[], content=text)
                return delta

            current_tool_call = dict()
            if tool_call_portion:
                current_tool_call_matches = self.stream_tool_call_portion_regex.match(
                    tool_call_portion
                )
                if current_tool_call_matches:
                    tool_type, tool_name, tool_args = current_tool_call_matches.groups()
                    current_tool_call["name"] = tool_name
                    current_tool_call["arguments"] = tool_args
                else:
                    # Arguments not visible yet; try to match just the name.
                    current_tool_call_name_matches = (
                        self.stream_tool_call_name_regex.match(tool_call_portion)
                    )
                    if current_tool_call_name_matches:
                        tool_type, tool_name = current_tool_call_name_matches.groups()
                        current_tool_call["name"] = tool_name
                        current_tool_call["arguments"] = ""
                    else:
                        logger.debug("Not enough token")
                        return None

            # case - we haven't sent the tool name yet. If it's available, send
            # it. otherwise, wait until it's available.
            if not self.current_tool_name_sent:
                if current_tool_call is None:
                    return None
                function_name: str | None = current_tool_call.get("name")
                if function_name:
                    self.current_tool_name_sent = True
                    return DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                type="function",
                                id=make_tool_call_id(),
                                function=DeltaFunctionCall(
                                    name=function_name
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                else:
                    return None

            # case -- otherwise, send the tool call delta

            # if the tool call portion is None, send the delta as text
            if tool_call_portion is None:
                # if there's text but not tool calls, send that -
                # otherwise None to skip chunk
                delta = (
                    DeltaMessage(content=delta_text)
                    if text_portion is not None
                    else None
                )
                return delta

            # now, the nitty-gritty of tool calls
            # now we have the portion to parse as tool call.

            logger.debug(
                "Trying to parse current tool call with ID %s", self.current_tool_id
            )

            # if we're starting a new tool call, push an empty object in as
            # a placeholder for the arguments
            if len(self.prev_tool_call_arr) <= self.current_tool_id:
                self.prev_tool_call_arr.append({})

            # main logic for tool parsing here - compare prev. partially-parsed
            # JSON to the current partially-parsed JSON
            prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                "arguments"
            )
            cur_arguments = current_tool_call.get("arguments")

            logger.debug("diffing old arguments: %s", prev_arguments)
            logger.debug("against new ones: %s", cur_arguments)

            # case -- no arguments have been created yet. skip sending a delta.
            if not cur_arguments and not prev_arguments:
                logger.debug("Skipping text %s - no arguments", delta_text)
                delta = None

            # case -- prev arguments are defined, but non are now.
            # probably impossible, but not a fatal error - just keep going
            elif not cur_arguments and prev_arguments:
                logger.error(
                    "should be impossible to have arguments reset "
                    "mid-call. skipping streaming anything."
                )
                delta = None

            # case -- we now have the first info about arguments available from
            # autocompleting the JSON
            elif cur_arguments and not prev_arguments:
                delta = DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            function=DeltaFunctionCall(
                                arguments=cur_arguments
                            ).model_dump(exclude_none=True),
                        )
                    ]
                )
                self.streamed_args_for_tool[self.current_tool_id] = cur_arguments

            # last case -- we have an update to existing arguments.
            elif cur_arguments and prev_arguments:
                # Only stream the suffix when current strictly extends prev.
                if (
                    isinstance(delta_text, str)
                    and cur_arguments != prev_arguments
                    and len(cur_arguments) > len(prev_arguments)
                    and cur_arguments.startswith(prev_arguments)
                ):
                    delta_arguments = cur_arguments[len(prev_arguments) :]
                    logger.debug("got diff %s", delta_text)

                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                function=DeltaFunctionCall(
                                    arguments=delta_arguments
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.streamed_args_for_tool[self.current_tool_id] = cur_arguments
                else:
                    delta = None

            # handle saving the state for the current tool into
            # the "prev" list for use in diffing for the next iteration
            if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
                self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
            else:
                self.prev_tool_call_arr.append(current_tool_call)

            return delta

        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            return None  # do not stream a delta. skip this token ID.
210
vllm/tool_parsers/ernie45_tool_parser.py
Normal file
210
vllm/tool_parsers/ernie45_tool_parser.py
Normal file
@@ -0,0 +1,210 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Ernie45ToolParser(ToolParser):
    """Tool parser for Ernie 4.5 '<tool_call>{json}</tool_call>' markup."""

    def __init__(self, tokenizer: TokenizerLike):
        """
        Ernie thinking model format:
        abc\n</think>\n\n\n<tool_call>\ndef\n</tool_call>\n
        """
        super().__init__(tokenizer)
        # Streaming bookkeeping: what has already been emitted per tool.
        self.current_tool_name_sent = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id = -1
        self.streamed_args_for_tool: list[str] = []
        # Literal marker strings in the model's output text.
        self.think_end_token = "</think>"
        self.response_start_token: str = "<response>"
        self.response_end_token: str = "</response>"
        self.tool_call_start_token = "<tool_call>"
        self.tool_call_end_token = "</tool_call>"
        # The first <tool_call> also marks the start of the whole section.
        self.tool_calls_start_token = self.tool_call_start_token
        # Byte-fallback newline token -- presumably how this tokenizer spells
        # "\n"; verify against the actual vocabulary.
        self.newline_token: str = "<0x0A>"

        # Non-greedy: one JSON object per <tool_call> block.
        self.tool_call_regex = re.compile(
            r"<tool_call>\s*(?P<json>\{.*?\})\s*</tool_call>", re.DOTALL
        )

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        # Token ids of the markers; .get() returns None for tokens missing
        # from the vocabulary, so parser_token_ids entries may be None.
        self.think_end_token_id = self.vocab.get(self.think_end_token)
        self.response_start_token_id = self.vocab.get(self.response_start_token)
        self.response_end_token_id = self.vocab.get(self.response_end_token)
        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
        self.newline_token_id = self.vocab.get(self.newline_token)
        self.parser_token_ids = [
            self.think_end_token_id,
            self.response_start_token_id,
            self.response_end_token_id,
        ]

        # Accumulates streamed text across extract_tool_calls_streaming calls.
        self._buffer = ""
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
# sanity check; avoid unnecessary processing
|
||||
if self.tool_calls_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
tool_call_json_list = self.tool_call_regex.findall(model_output)
|
||||
|
||||
tool_calls = []
|
||||
for tool_call_json in tool_call_json_list:
|
||||
tool_call_dict = json.loads(tool_call_json)
|
||||
args_str = json.dumps(
|
||||
tool_call_dict.get("arguments", {}), ensure_ascii=False
|
||||
)
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=tool_call_dict.get("name", ""),
|
||||
arguments=args_str,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
content = model_output[
|
||||
: model_output.find(self.tool_calls_start_token)
|
||||
].rstrip("\n")
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> DeltaMessage | None:
|
||||
self._buffer += delta_text
|
||||
cur_text = self._buffer
|
||||
start_idx = cur_text.find(self.tool_call_start_token)
|
||||
if start_idx == -1:
|
||||
self._buffer = ""
|
||||
# At least one toolcall has been completed
|
||||
if self.current_tool_id > 0:
|
||||
cur_text = ""
|
||||
if self.current_tool_id == -1 and all(
|
||||
token_id == self.newline_token_id for token_id in previous_token_ids
|
||||
):
|
||||
cur_text = cur_text.strip("\n")
|
||||
|
||||
# handle <response> </response> when tool_call is not triggered
|
||||
# cur_text === delta_text
|
||||
content = cur_text
|
||||
if self.response_start_token_id in delta_token_ids:
|
||||
content = content.lstrip("\n")
|
||||
response_start_idx = content.find(self.response_start_token)
|
||||
content = content[response_start_idx + len(self.response_start_token) :]
|
||||
# if have </response>, remove it
|
||||
response_end_idx = content.rfind(self.response_end_token)
|
||||
if response_end_idx != -1:
|
||||
content = content[:response_end_idx]
|
||||
elif self.response_end_token_id in delta_token_ids:
|
||||
response_end_idx = content.rfind(self.response_end_token)
|
||||
content = content[:response_end_idx]
|
||||
# remove \n after </think> or <response> or </response>
|
||||
if (
|
||||
len(previous_token_ids) > 0
|
||||
and previous_token_ids[-1] in self.parser_token_ids
|
||||
) and (
|
||||
len(delta_token_ids) > 0 and delta_token_ids[0] == self.newline_token_id
|
||||
):
|
||||
content = content.lstrip("\n")
|
||||
|
||||
return DeltaMessage(content=content if content else None)
|
||||
logger.debug("cur_text = %s", cur_text)
|
||||
end_idx = cur_text.find(self.tool_call_end_token)
|
||||
if end_idx != -1:
|
||||
if self.current_tool_id == -1:
|
||||
self.current_tool_id = 0
|
||||
self.prev_tool_call_arr = []
|
||||
self.streamed_args_for_tool = []
|
||||
while len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||
self.prev_tool_call_arr.append({})
|
||||
while len(self.streamed_args_for_tool) <= self.current_tool_id:
|
||||
self.streamed_args_for_tool.append("")
|
||||
|
||||
extracted_tool_calls = self.extract_tool_calls(
|
||||
cur_text[: end_idx + len(self.tool_call_end_token)], request
|
||||
)
|
||||
|
||||
if len(extracted_tool_calls.tool_calls) == 0:
|
||||
logger.warning("Failed to extract any tool calls.")
|
||||
return None
|
||||
tool_call = extracted_tool_calls.tool_calls[0]
|
||||
self.prev_tool_call_arr[self.current_tool_id] = {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": json.loads(tool_call.function.arguments),
|
||||
}
|
||||
self.streamed_args_for_tool[self.current_tool_id] = (
|
||||
tool_call.function.arguments
|
||||
)
|
||||
delta = DeltaMessage(
|
||||
content=extracted_tool_calls.content,
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
id=tool_call.id,
|
||||
type=tool_call.type,
|
||||
function=DeltaFunctionCall(
|
||||
name=tool_call.function.name,
|
||||
arguments=tool_call.function.arguments,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
self.current_tool_id += 1
|
||||
self._buffer = cur_text[end_idx + len(self.tool_call_end_token) :]
|
||||
return delta
|
||||
|
||||
self._buffer = cur_text[start_idx:]
|
||||
content = cur_text[:start_idx].rstrip("\n")
|
||||
return DeltaMessage(content=content if content else None)
|
||||
190
vllm/tool_parsers/gigachat3_tool_parser.py
Normal file
190
vllm/tool_parsers/gigachat3_tool_parser.py
Normal file
@@ -0,0 +1,190 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import ToolParser
|
||||
|
||||
# Module-level logger shared by this parser implementation.
logger = init_logger(__name__)

# Matches the "function call" trigger (optionally followed by a role
# separator) and captures everything from the opening brace onward.
REGEX_FUNCTION_CALL = re.compile(
    r"function call(?:<\|role_sep\|>\n)?(\{.*)",
    re.DOTALL,
)

# Captures the value of a "name" field inside the JSON tail.
NAME_REGEX = re.compile(
    r'"name"\s*:\s*"([^"]*)"',
    re.DOTALL,
)

# Captures everything after the "arguments" key (possibly partial JSON
# while streaming).
ARGS_REGEX = re.compile(
    r'"arguments"\s*:\s*(.*)',
    re.DOTALL,
)
|
||||
|
||||
|
||||
class GigaChat3ToolParser(ToolParser):
    """Tool-call parser for GigaChat-3 models.

    The model announces a single tool call with the literal text
    ``function call`` followed by a JSON object carrying ``name`` and
    ``arguments``.
    """

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)
        # Streaming state: whether the trigger has been seen, whether the
        # function name was already emitted, and the assigned call id.
        self.tool_started: bool = False
        self.tool_name_sent: bool = False
        self.tool_id: str | None = None
        self.prev_tool_call_arr: list[dict] = []
        # Content withheld while it could still be a trigger prefix.
        self.content_buffer: str = ""
        self.trigger_start = "function call{"

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract the tool call from a complete response, if present.

        Falls back to returning the whole output as content when the
        trigger is absent or the JSON payload is malformed or not an
        object with ``name`` and ``arguments``.
        """
        match = REGEX_FUNCTION_CALL.search(model_output)
        if not match:
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )
        json_candidate = match.group(1).strip()
        try:
            data = json.loads(json_candidate)
        except json.JSONDecodeError:
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )
        # The payload must be an object with both expected keys.
        if not (isinstance(data, dict) and "name" in data and "arguments" in data):
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )
        name = data["name"]
        args = data["arguments"]
        # The protocol carries arguments as a JSON-encoded string.
        if not isinstance(args, str):
            args = json.dumps(args, ensure_ascii=False)

        tool_calls = [
            ToolCall(
                type="function",
                function=FunctionCall(
                    name=name,
                    arguments=args,
                ),
            )
        ]
        # Any non-blank text before the trigger is returned as content.
        prefix = model_output[: match.start()]
        content = prefix.rstrip() if prefix and prefix.strip() else None

        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=content,
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Incrementally extract the tool call from a streamed response.

        Withholds text while it could still be the ``function call``
        trigger; after the trigger, emits the name once, then argument
        diffs parsed out of the growing partial-JSON tail.
        """
        func_name = None
        cur_args = None
        if not self.tool_started:
            match = REGEX_FUNCTION_CALL.search(current_text)
            if match:
                self.tool_started = True
                # Withheld text was (part of) the trigger: drop it.
                self.content_buffer = ""
            else:
                self.content_buffer += delta_text
                clean_buffer = self.content_buffer.lstrip()
                is_prefix = self.trigger_start.startswith(clean_buffer)
                starts_with_trigger = clean_buffer.startswith(self.trigger_start)
                if is_prefix or starts_with_trigger:
                    # Could still become a tool call: keep withholding.
                    return None
                else:
                    # Definitely plain content: flush the buffer.
                    flush_text = self.content_buffer
                    self.content_buffer = ""
                    return DeltaMessage(content=flush_text)

        match = REGEX_FUNCTION_CALL.search(current_text)
        if not match:
            return None
        json_tail = match.group(1).strip()
        name_match = NAME_REGEX.search(json_tail)
        if name_match:
            func_name = name_match.group(1)
        args_match = ARGS_REGEX.search(json_tail)
        if args_match:
            cur_args = args_match.group(1).strip()
            if cur_args.endswith("}"):  # last '}' end of json
                # Drop the enclosing object's final brace when the rest
                # already parses as complete JSON on its own.
                try:
                    candidate = cur_args[:-1].strip()
                    json.loads(candidate)
                    cur_args = candidate
                except json.JSONDecodeError:
                    pass
        if not self.prev_tool_call_arr:
            self.prev_tool_call_arr.append({})
        if not self.tool_name_sent:
            # The full name must be known before anything is emitted.
            if not func_name:
                return None
            self.tool_name_sent = True
            self.tool_id = make_tool_call_id()
            self.prev_tool_call_arr[0]["name"] = func_name
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=0,
                        id=self.tool_id,
                        type="function",
                        function=DeltaFunctionCall(
                            name=func_name,
                        ).model_dump(exclude_none=True),
                    )
                ],
                content=None,
            )
        if cur_args is None:
            return None
        prev_args = self.prev_tool_call_arr[0].get("arguments", "")
        # Emit only the newly appended suffix of the argument string.
        if not prev_args:
            delta_args = cur_args
        elif cur_args.startswith(prev_args):
            delta_args = cur_args[len(prev_args) :]
        else:
            # Tail no longer extends what was already sent: skip this
            # chunk instead of emitting inconsistent argument text.
            return None
        if not delta_args:
            return None
        self.prev_tool_call_arr[0]["arguments"] = cur_args
        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=0,
                    function=DeltaFunctionCall(
                        arguments=delta_args,
                    ).model_dump(exclude_none=True),
                )
            ],
            content=None,
        )
|
||||
200
vllm/tool_parsers/glm4_moe_tool_parser.py
Normal file
200
vllm/tool_parsers/glm4_moe_tool_parser.py
Normal file
@@ -0,0 +1,200 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionToolsParam,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
|
||||
# Module-level logger shared by this parser implementation.
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Glm4MoeModelToolParser(ToolParser):
    """Tool-call parser for GLM-4 MoE models.

    Calls are wrapped in ``<tool_call>...</tool_call>`` blocks whose first
    line is the function name and whose body is a sequence of
    ``<arg_key>``/``<arg_value>`` pairs.
    """

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)
        # Streaming state (per-tool bookkeeping; -1 = no call seen yet).
        self.current_tool_name_sent = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id = -1
        self.streamed_args_for_tool: list[str] = []
        self.tool_call_start_token = "<tool_call>"
        self.tool_call_end_token = "</tool_call>"

        self.tool_calls_start_token = self.tool_call_start_token

        # Matchers for a whole block, its name/body split, and its
        # key/value pairs respectively.
        self.func_call_regex = re.compile(r"<tool_call>.*?</tool_call>", re.DOTALL)
        self.func_detail_regex = re.compile(
            r"<tool_call>([^\n]*)\n(.*)</tool_call>", re.DOTALL
        )
        self.func_arg_regex = re.compile(
            r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", re.DOTALL
        )
        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
        # Accumulates streamed text until a complete tool call is seen.
        self._buffer = ""

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract all tool calls from a complete (non-streaming) response."""

        def _is_string_type(
            tool_name: str,
            arg_name: str,
            tools: list[ChatCompletionToolsParam] | None,
        ) -> bool:
            # Consult the request's tool schema: string-typed parameters
            # are kept verbatim instead of being deserialized.
            if tools is None:
                return False
            for tool in tools:
                if tool.function.name == tool_name:
                    if tool.function.parameters is None:
                        return False
                    arg_type = (
                        tool.function.parameters.get("properties", {})
                        .get(arg_name, {})
                        .get("type", None)
                    )
                    return arg_type == "string"
            logger.debug("No tool named '%s'.", tool_name)
            return False

        def _deserialize(value: str) -> Any:
            # Best effort: JSON first, then Python literal, else raw text.
            try:
                return json.loads(value)
            except Exception:
                pass

            try:
                return ast.literal_eval(value)
            except Exception:
                pass
            return value

        matched_tool_calls = self.func_call_regex.findall(model_output)
        logger.debug("model_output: %s", model_output)
        try:
            tool_calls = []
            for match in matched_tool_calls:
                tc_detail = self.func_detail_regex.search(match)
                # Group 1 = function name (first line), group 2 = body.
                tc_name = tc_detail.group(1)
                tc_args = tc_detail.group(2)
                pairs = self.func_arg_regex.findall(tc_args)
                arg_dct = {}
                for key, value in pairs:
                    arg_key = key.strip()
                    arg_val = value.strip()
                    if not _is_string_type(tc_name, arg_key, request.tools):
                        arg_val = _deserialize(arg_val)
                    logger.debug("arg_key = %s, arg_val = %s", arg_key, arg_val)
                    arg_dct[arg_key] = arg_val
                tool_calls.append(
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=tc_name, arguments=json.dumps(arg_dct)
                        ),
                    )
                )
        except Exception:
            logger.exception("Failed to extract tool call spec")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )
        else:
            if len(tool_calls) > 0:
                # Content is everything before the first tool-call marker.
                content = model_output[: model_output.find(self.tool_calls_start_token)]
                return ExtractedToolCallInformation(
                    tools_called=True, tool_calls=tool_calls, content=content
                )
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Incrementally extract tool calls from a streamed response.

        Buffers delta text until a complete ``<tool_call>`` block arrives,
        then emits the whole call as a single delta.
        """
        self._buffer += delta_text
        cur_text = self._buffer
        start_idx = cur_text.find(self.tool_call_start_token)
        if start_idx == -1:
            self._buffer = ""
            # Once at least one call has completed, trailing text after
            # the tool calls is suppressed.
            if self.current_tool_id > 0:
                cur_text = ""
            return DeltaMessage(content=cur_text)
        logger.debug("cur_text = %s", cur_text)
        end_idx = cur_text.find(self.tool_call_end_token)
        if end_idx != -1:
            # First completed call: lazily initialize per-call state.
            if self.current_tool_id == -1:
                self.current_tool_id = 0
                self.prev_tool_call_arr = []
                self.streamed_args_for_tool = []
            while len(self.prev_tool_call_arr) <= self.current_tool_id:
                self.prev_tool_call_arr.append({})
            while len(self.streamed_args_for_tool) <= self.current_tool_id:
                self.streamed_args_for_tool.append("")

            # Reuse the non-streaming parser on the completed block.
            extracted_tool_calls = self.extract_tool_calls(
                cur_text[: end_idx + len(self.tool_call_end_token)], request
            )

            if len(extracted_tool_calls.tool_calls) == 0:
                logger.warning("Failed to extract any tool calls.")
                return None
            tool_call = extracted_tool_calls.tool_calls[0]
            self.prev_tool_call_arr[self.current_tool_id] = {
                "name": tool_call.function.name,
                "arguments": json.loads(tool_call.function.arguments),
            }
            self.streamed_args_for_tool[self.current_tool_id] = (
                tool_call.function.arguments
            )
            # The whole call (name + full arguments) goes out in one delta.
            delta = DeltaMessage(
                content=extracted_tool_calls.content,
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        id=tool_call.id,
                        type=tool_call.type,
                        function=DeltaFunctionCall(
                            name=tool_call.function.name,
                            arguments=tool_call.function.arguments,
                        ),
                    )
                ],
            )
            self.current_tool_id += 1
            # Keep any remainder after </tool_call> for the next round.
            self._buffer = cur_text[end_idx + len(self.tool_call_end_token) :]
            return delta

        # Partial call: flush preceding text, keep the rest buffered.
        self._buffer = cur_text[start_idx:]
        return DeltaMessage(content=cur_text[:start_idx])
|
||||
273
vllm/tool_parsers/granite_20b_fc_tool_parser.py
Normal file
273
vllm/tool_parsers/granite_20b_fc_tool_parser.py
Normal file
@@ -0,0 +1,273 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from json import JSONDecoder
|
||||
|
||||
import partial_json_parser
|
||||
import regex as re
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.tool_parsers.utils import (
|
||||
consume_space,
|
||||
find_common_prefix,
|
||||
is_complete_json,
|
||||
partial_json_loads,
|
||||
)
|
||||
|
||||
# Module-level logger shared by this parser implementation.
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Granite20bFCToolParser(ToolParser):
    """
    Tool call parser for the granite-20b-functioncalling model intended
    for use with the examples/tool_chat_template_granite20b_fc.jinja
    template.

    Used when --enable-auto-tool-choice --tool-call-parser granite-20-fc
    are all set
    """

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)

        # Marker the model emits before each call's JSON payload.
        self.bot_token = "<function_call>"
        self.tool_start_token = self.bot_token
        # Matches the marker plus any following whitespace.
        self.tool_call_regex = re.compile(r"<function_call>\s*")

    def extract_tool_calls(
        self, model_output: str, request: ChatCompletionRequest
    ) -> ExtractedToolCallInformation:
        """Extract all tool calls from a complete (non-streaming) response.

        Each ``<function_call>`` marker is followed by a JSON object with
        ``name`` and ``arguments``; on any parsing error the whole output
        is returned as plain content.
        """
        if self.tool_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        dec = JSONDecoder()
        try:
            matches = list(self.tool_call_regex.finditer(model_output))
            logger.debug("Found %d tool call matches", len(matches))

            raw_function_calls = []

            for i, match in enumerate(matches):
                # position after the <function_call> tag
                start_of_json = match.end()
                # end_index == the start of the next function call
                # (if exists)
                next_function_call_start = (
                    matches[i + 1].start() if i + 1 < len(matches) else None
                )

                # raw_decode tolerates trailing text after the JSON value.
                raw_function_calls.append(
                    dec.raw_decode(
                        model_output[start_of_json:next_function_call_start]
                    )[0]
                )

            logger.debug("Extracted %d tool calls", len(raw_function_calls))
            tool_calls = [
                ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=function_call["name"],
                        # function call args are JSON but as a string
                        arguments=json.dumps(
                            function_call["arguments"], ensure_ascii=False
                        ),
                    ),
                )
                for function_call in raw_function_calls
            ]

            # Content is everything before the first marker.
            content = model_output[: model_output.find(self.bot_token)]
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None,
            )

        except Exception as e:
            logger.error("Error in extracting tool call from response %s", e)
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Incrementally extract tool calls from a streamed response.

        Re-parses the growing payload(s) with a partial-JSON parser on
        every delta, emitting each call's name once and then argument
        diffs as they stabilize.
        """
        # Possibly still mid-way through the bot token: wait for more.
        if len(current_text) < len(self.bot_token) and self.bot_token.startswith(
            current_text
        ):
            return None

        if not current_text.startswith(self.bot_token):
            return DeltaMessage(content=delta_text)

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
        try:
            tool_call_arr = []
            is_complete = []
            try:
                start_idx = len(self.bot_token)
                start_idx = consume_space(start_idx, current_text)

                # Walk every <function_call> payload seen so far, noting
                # whether each one is complete JSON yet.
                while start_idx < len(current_text):
                    (obj, end_idx) = partial_json_loads(current_text[start_idx:], flags)
                    is_complete.append(
                        is_complete_json(current_text[start_idx : start_idx + end_idx])
                    )
                    start_idx += end_idx
                    start_idx = consume_space(start_idx, current_text)
                    start_idx += len(self.bot_token)
                    start_idx = consume_space(start_idx, current_text)
                    tool_call_arr.append(obj)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug("not enough tokens to parse into JSON yet")
                return None

            # select as the current tool call the one we're on the state at
            current_tool_call: dict = (
                tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
            )

            # case -- if no tokens have been streamed for the tool, e.g.
            # only the array brackets, stream nothing
            if len(tool_call_arr) == 0:
                return None

            # case: we are starting a new tool in the array
            # -> array has > 0 length AND length has moved past cursor
            elif (
                len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
            ):
                # if we're moving on to a new call, first make sure we
                # haven't missed anything in the previous one that was
                # auto-generated due to JSON completions, but wasn't
                # streamed to the client yet.
                if self.current_tool_id >= 0:
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                        sent = len(self.streamed_args_for_tool[self.current_tool_id])
                        argument_diff = cur_args_json[sent:]

                        logger.debug("got arguments diff: %s", argument_diff)
                        delta = DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_id,
                                    function=DeltaFunctionCall(
                                        arguments=argument_diff
                                    ).model_dump(exclude_none=True),
                                )
                            ]
                        )
                        self.streamed_args_for_tool[self.current_tool_id] += (
                            argument_diff
                        )
                    else:
                        delta = None
                else:
                    delta = None
                # re-set stuff pertaining to progress in the current tool
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("starting on new tool %d", self.current_tool_id)
                return delta

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            elif not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:
                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                type="function",
                                id=make_tool_call_id(),
                                function=DeltaFunctionCall(
                                    name=function_name
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.current_tool_name_sent = True
                else:
                    delta = None

            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                cur_arguments = current_tool_call.get("arguments")
                delta = None

                if cur_arguments:
                    sent = len(self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                    prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                        "arguments"
                    )

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        # JSON closed: emit everything not yet streamed.
                        argument_diff = cur_args_json[sent:]
                    elif prev_arguments:
                        prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
                        if cur_args_json != prev_args_json:
                            # Only the stable common prefix is safe to
                            # send while the JSON is still growing.
                            prefix = find_common_prefix(prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        delta = DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_id,
                                    function=DeltaFunctionCall(
                                        arguments=argument_diff
                                    ).model_dump(exclude_none=True),
                                )
                            ]
                        )
                        self.streamed_args_for_tool[self.current_tool_id] += (
                            argument_diff
                        )

            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception as e:
            logger.error("Error trying to handle streaming tool call: %s", e)
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction error"
            )
            return None
|
||||
253
vllm/tool_parsers/granite_tool_parser.py
Normal file
253
vllm/tool_parsers/granite_tool_parser.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.tool_parsers.utils import (
|
||||
consume_space,
|
||||
find_common_prefix,
|
||||
is_complete_json,
|
||||
partial_json_loads,
|
||||
)
|
||||
|
||||
# Module-level logger shared by this parser implementation.
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GraniteToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for the granite 3.0 models. Intended
|
||||
for use with the examples/tool_chat_template_granite.jinja
|
||||
template.
|
||||
|
||||
Used when --enable-auto-tool-choice --tool-call-parser granite
|
||||
are all set
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
# for granite 3.0, the token `<|tool_call|>`
|
||||
self.bot_token = "<|tool_call|>"
|
||||
# for granite 3.1, the string `<tool_call>`
|
||||
self.bot_string = "<tool_call>"
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str, request: ChatCompletionRequest
|
||||
) -> ExtractedToolCallInformation:
|
||||
stripped = (
|
||||
model_output.strip()
|
||||
.removeprefix(self.bot_token)
|
||||
.removeprefix(self.bot_string)
|
||||
.lstrip()
|
||||
)
|
||||
if not stripped or stripped[0] != "[":
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
try:
|
||||
raw_function_calls = json.loads(stripped)
|
||||
if not isinstance(raw_function_calls, list):
|
||||
raise Exception(
|
||||
f"Expected dict or list, got {type(raw_function_calls)}"
|
||||
)
|
||||
|
||||
logger.debug("Extracted %d tool calls", len(raw_function_calls))
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=function_call["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(
|
||||
function_call["arguments"], ensure_ascii=False
|
||||
),
|
||||
),
|
||||
)
|
||||
for function_call in raw_function_calls
|
||||
]
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=None,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in extracting tool call from response %s", e)
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None:
    """Incrementally extract tool-call deltas while the model streams.

    Partially parses the JSON array of tool calls seen so far
    (``partial_json_loads``), diffs it against the state kept on ``self``
    (``prev_tool_call_arr`` / ``streamed_args_for_tool``), and emits only
    the not-yet-streamed pieces (name first, then argument fragments).
    Returns ``None`` when nothing new can be streamed for this chunk,
    or a content-only ``DeltaMessage`` when the output is not a tool call.
    """
    # Skip the optional bot token/string and surrounding whitespace so
    # start_idx points at what should be the opening '[' of the array.
    start_idx = consume_space(0, current_text)
    if current_text[start_idx:].startswith(self.bot_token):
        start_idx = consume_space(start_idx + len(self.bot_token), current_text)
    if current_text[start_idx:].startswith(self.bot_string):
        start_idx = consume_space(start_idx + len(self.bot_string), current_text)
    if (
        not current_text
        or start_idx >= len(current_text)
        or current_text[start_idx] != "["
    ):
        # Not (yet) a tool-call array: pass the delta through as content.
        return DeltaMessage(content=delta_text)

    # bit mask flags for partial JSON parsing. If the name hasn't been
    # sent yet, don't allow sending
    # an incomplete string since OpenAI only ever (as far as I have
    # seen) allows sending the entire tool/ function name at once.
    flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
    try:
        tool_call_arr = None
        is_complete = None
        try:
            tool_calls, end_idx = partial_json_loads(
                current_text[start_idx:], flags
            )
            if type(tool_calls) is list:
                tool_call_arr = tool_calls
            else:
                return DeltaMessage(content=delta_text)

            # Every fully-parsed element is complete; only the last one
            # may still be a partial (auto-completed) JSON object.
            is_complete = [True] * len(tool_calls)
            if not is_complete_json(current_text[start_idx : start_idx + end_idx]):
                is_complete[-1] = False
        except partial_json_parser.core.exceptions.MalformedJSON:
            logger.debug("not enough tokens to parse into JSON yet")
            return None

        # case -- if no tokens have been streamed for the tool, e.g.
        # only the array brackets, stream nothing
        if not tool_call_arr:
            return None

        # select as the current tool call the one we're on the state at
        current_tool_call: dict = tool_call_arr[self.current_tool_id]

        delta = None
        # case: we are starting a new tool in the array
        # -> array has > 0 length AND length has moved past cursor
        if len(tool_call_arr) > self.current_tool_id + 1:
            # if we're moving on to a new call, first make sure we
            # haven't missed anything in the previous one that was
            # auto-generated due to JSON completions, but wasn't
            # streamed to the client yet.
            if self.current_tool_id >= 0:
                cur_arguments = current_tool_call.get("arguments")
                if cur_arguments:
                    cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                    sent = len(self.streamed_args_for_tool[self.current_tool_id])
                    argument_diff = cur_args_json[sent:]

                    logger.debug("got arguments diff: %s", argument_diff)
                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                function=DeltaFunctionCall(
                                    arguments=argument_diff
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.streamed_args_for_tool[self.current_tool_id] += (
                        argument_diff
                    )

            # re-set stuff pertaining to progress in the current tool
            self.current_tool_id = len(tool_call_arr) - 1
            self.current_tool_name_sent = False
            self.streamed_args_for_tool.append("")
            logger.debug("starting on new tool %d", self.current_tool_id)
            return delta

        # if the current tool name hasn't been sent, send if available
        # - otherwise send nothing
        elif not self.current_tool_name_sent:
            function_name = current_tool_call.get("name")
            if function_name:
                delta = DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            type="function",
                            id=make_tool_call_id(),
                            function=DeltaFunctionCall(
                                name=function_name
                            ).model_dump(exclude_none=True),
                        )
                    ]
                )
                self.current_tool_name_sent = True

        # now we know we're on the same tool call and we're streaming
        # arguments
        else:
            cur_arguments = current_tool_call.get("arguments")

            if cur_arguments:
                sent = len(self.streamed_args_for_tool[self.current_tool_id])
                cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                    "arguments"
                )

                argument_diff = None
                if is_complete[self.current_tool_id]:
                    # The call's JSON is complete: flush everything that
                    # has not been streamed yet.
                    argument_diff = cur_args_json[sent:]
                elif prev_arguments:
                    # Still partial: only stream the portion both parses
                    # agree on, to avoid emitting auto-completed JSON.
                    prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
                    if cur_args_json != prev_args_json:
                        prefix = find_common_prefix(prev_args_json, cur_args_json)
                        argument_diff = prefix[sent:]

                if argument_diff is not None:
                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                function=DeltaFunctionCall(
                                    arguments=argument_diff
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.streamed_args_for_tool[self.current_tool_id] += (
                        argument_diff
                    )

        self.prev_tool_call_arr = tool_call_arr
        return delta

    except Exception as e:
        logger.error("Error trying to handle streaming tool call: %s", e)
        logger.debug(
            "Skipping chunk as a result of tool streaming extraction error"
        )
        return None
|
||||
495
vllm/tool_parsers/hermes_tool_parser.py
Normal file
495
vllm/tool_parsers/hermes_tool_parser.py
Normal file
@@ -0,0 +1,495 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
|
||||
import partial_json_parser
|
||||
import regex as re
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Hermes2ProToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: TokenizerLike):
    """Initialize the Hermes 2 Pro tool parser.

    Precomputes the token-id sequences and per-token string pieces of the
    ``<tool_call>``/``</tool_call>`` markers so that streaming chunks can
    be buffered until a marker is complete.

    Args:
        tokenizer: the model's tokenizer; must be non-None.

    Raises:
        ValueError: if no model tokenizer is available after base-class
            initialization.
    """
    super().__init__(tokenizer)

    if isinstance(tokenizer, MistralTokenizer):
        # Mistral tokenizers are not expected with Hermes models; fall
        # back to the wrapped underlying tokenizer anyway.
        logger.error("Detected Mistral tokenizer when using a Hermes model")
        self.model_tokenizer = tokenizer.tokenizer

    # Streaming state: which call we are on, what has been emitted so far.
    self.current_tool_name_sent: bool = False
    self.prev_tool_call_arr: list[dict] = []
    self.current_tool_id: int = -1
    self.streamed_args_for_tool: list[
        str
    ] = []  # map what has been streamed for each tool so far to a list

    self.tool_call_start_token: str = "<tool_call>"
    self.tool_call_end_token: str = "</tool_call>"

    # Two alternatives: a closed <tool_call>...</tool_call> pair, or an
    # opening tag running to end-of-string (still-streaming call).
    self.tool_call_regex = re.compile(
        r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL
    )
    self.scratch_pad_regex = re.compile(
        r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL
    )

    if not self.model_tokenizer:
        raise ValueError(
            "The model tokenizer must be passed to the ToolParser "
            "constructor during construction."
        )
    self.tool_call_start_token_ids = self.model_tokenizer.encode(
        self.tool_call_start_token, add_special_tokens=False
    )
    self.tool_call_end_token_ids = self.model_tokenizer.encode(
        self.tool_call_end_token, add_special_tokens=False
    )

    # Per-token decoded string pieces of each marker, used by
    # tool_call_delta_buffer to recognize partially-emitted markers.
    self.tool_call_start_token_array = [
        self.model_tokenizer.decode([token_id])
        for token_id in self.tool_call_start_token_ids
    ]

    self.tool_call_end_token_array = [
        self.model_tokenizer.decode([token_id])
        for token_id in self.tool_call_end_token_ids
    ]

    # Accumulates marker pieces until the marker completes (or breaks).
    self.buffered_delta_text = ""
|
||||
|
||||
# Very simple idea: when encountering tokens like <, tool, _call, >,
|
||||
# <, /, tool, _call, >, store them in a buffer.
|
||||
# When the last token is encountered, empty the buffer and return it.
|
||||
# If a token appears in an incorrect sequence while storing in the buffer,
|
||||
# return the preceding buffer along with the token.
|
||||
def tool_call_delta_buffer(self, delta_text: str):
|
||||
# If the sequence of tool_call_start or tool_call_end tokens is not yet
|
||||
# complete, fill the buffer with the token and return "".
|
||||
if (
|
||||
delta_text in self.tool_call_start_token_array
|
||||
or delta_text in self.tool_call_end_token_array
|
||||
):
|
||||
# If delta_text is the last token of tool_call_start_token or
|
||||
# tool_call_end_token, empty the buffer and return
|
||||
# the buffered text + delta_text.
|
||||
if (
|
||||
delta_text == self.tool_call_start_token_array[-1]
|
||||
or delta_text == self.tool_call_end_token_array[-1]
|
||||
):
|
||||
buffered_text = self.buffered_delta_text
|
||||
self.buffered_delta_text = ""
|
||||
return buffered_text + delta_text
|
||||
else:
|
||||
self.buffered_delta_text = self.buffered_delta_text + delta_text
|
||||
return ""
|
||||
else:
|
||||
if self.buffered_delta_text:
|
||||
buffered_text = self.buffered_delta_text
|
||||
self.buffered_delta_text = ""
|
||||
return buffered_text + delta_text
|
||||
else:
|
||||
return delta_text
|
||||
|
||||
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
    """Adjust the request so tool-call markup survives detokenization.

    The tool_call tokens are marked "special" in some models; if they
    were skipped before reaching the tool parser, tool calling would
    break. So whenever tools may actually be invoked, force special
    tokens to be kept in the output.
    """
    request = super().adjust_request(request)
    tools_may_be_called = bool(request.tools) and request.tool_choice != "none"
    if tools_may_be_called:
        request.skip_special_tokens = False
    return request
|
||||
|
||||
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """Extract tool calls from a complete (non-streaming) model response.

    Scans for ``<tool_call>...</tool_call>`` blocks (or an unclosed
    ``<tool_call>`` running to end-of-string) and parses each captured
    body as a JSON object with ``name`` and ``arguments`` keys. Text
    before the first marker is returned as content. On any parse error
    the whole output is returned as plain content.
    """
    # sanity check; avoid unnecessary processing
    if self.tool_call_start_token not in model_output:
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    else:
        try:
            # there are two possible captures - between tags, or between a
            # tag and end-of-string so the result of
            # findall is an array of tuples where one is a function call and
            # the other is None
            function_call_tuples = self.tool_call_regex.findall(model_output)

            # load the JSON, and then use it to build the Function and
            # Tool Call
            raw_function_calls = [
                json.loads(match[0] if match[0] else match[1])
                for match in function_call_tuples
            ]
            tool_calls = [
                ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=function_call["name"],
                        # function call args are JSON but as a string
                        arguments=json.dumps(
                            function_call["arguments"], ensure_ascii=False
                        ),
                    ),
                )
                for function_call in raw_function_calls
            ]

            # Everything before the first marker is ordinary content.
            content = model_output[: model_output.find(self.tool_call_start_token)]
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None,
            )

        except Exception:
            logger.exception("Error in extracting tool call from response.")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )
|
||||
|
||||
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None:
    """Incrementally extract Hermes tool-call deltas while streaming.

    Classifies the stream position by counting <tool_call>/</tool_call>
    tags in previous vs. current text, partially parses the active
    tool-call body, and diffs it against ``self.prev_tool_call_arr`` /
    ``self.streamed_args_for_tool`` so each name/argument fragment is
    streamed exactly once. Returns ``None`` to skip a chunk.
    """
    # 1. All tokens are parsed based on _text, not token_ids.
    # 2. All incoming text data is processed by the tool_call_delta_buffer
    # function for buffering before being used for parsing.

    delta_text = self.tool_call_delta_buffer(delta_text)
    # If the last characters of previous_text
    # match self.buffered_delta_text, remove only the matching part.
    if (
        len(previous_text) >= len(self.buffered_delta_text)
        and previous_text[-len(self.buffered_delta_text) :]
        == self.buffered_delta_text
    ):
        previous_text = previous_text[: -len(self.buffered_delta_text)]
    current_text = previous_text + delta_text

    logger.debug("delta_text: %s", delta_text)
    logger.debug("delta_token_ids: %s", delta_token_ids)
    # check to see if we should be streaming a tool call - is there a
    # start marker anywhere in the text so far?
    if self.tool_call_start_token not in current_text:
        logger.debug("No tool call tokens found!")
        return DeltaMessage(content=delta_text)

    try:
        # figure out where we are in the parsing by counting tool call
        # start & end tags
        prev_tool_start_count = previous_text.count(self.tool_call_start_token)
        prev_tool_end_count = previous_text.count(self.tool_call_end_token)
        cur_tool_start_count = current_text.count(self.tool_call_start_token)
        cur_tool_end_count = current_text.count(self.tool_call_end_token)
        tool_call_portion = None
        text_portion = None

        # case: if we're generating text, OR rounding out a tool call
        if (
            cur_tool_start_count == cur_tool_end_count
            and prev_tool_end_count == cur_tool_end_count
            and self.tool_call_end_token not in delta_text
        ):
            logger.debug("Generating text content! skipping tool parsing.")
            return DeltaMessage(content=delta_text)

        if self.tool_call_end_token in delta_text:
            logger.debug("tool_call_end_token in delta_text")
            full_text = current_text + delta_text
            tool_call_portion = (
                full_text.split(self.tool_call_start_token)[-1]
                .split(self.tool_call_end_token)[0]
                .rstrip()
            )
            delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip()
            # NOTE(review): delta_text no longer contains the end token at
            # this point, so this split is a no-op and text_portion equals
            # the stripped delta — confirm intended.
            text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip()

        # case: if tool open & close tag counts don't match, we're doing
        # imaginary "else" block here
        # something with tools with this diff.
        # flags for partial JSON parting. exported constants from
        # "Allow" are handled via BIT MASK
        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR

        # case -- we're starting a new tool call
        if (
            cur_tool_start_count > cur_tool_end_count
            and cur_tool_start_count > prev_tool_start_count
        ):
            if len(delta_token_ids) > 1:
                tool_call_portion = current_text.split(self.tool_call_start_token)[
                    -1
                ]
            else:
                tool_call_portion = None
                delta = None

            text_portion = None

            # set cursors and state appropriately
            self.current_tool_id += 1
            self.current_tool_name_sent = False
            self.streamed_args_for_tool.append("")
            logger.debug("Starting on a new tool %s", self.current_tool_id)

        # case -- we're updating an existing tool call
        elif (
            cur_tool_start_count > cur_tool_end_count
            and cur_tool_start_count == prev_tool_start_count
        ):
            # get the portion of the text that's the tool call
            tool_call_portion = current_text.split(self.tool_call_start_token)[-1]
            text_portion = None

        # case -- the current tool call is being closed.
        elif (
            cur_tool_start_count == cur_tool_end_count
            and cur_tool_end_count >= prev_tool_end_count
        ):
            if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0:
                logger.debug("attempting to close tool call, but no tool call")
                return None
            diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
            if diff:
                # NOTE(review): `diff is str` compares against the type
                # object and is always False for a string value, and the
                # result is overwritten by the assignment below either
                # way — this expression looks like dead code; confirm.
                diff = (
                    diff.encode("utf-8").decode("unicode_escape")
                    if diff is str
                    else diff
                )
                if '"}' not in delta_text:
                    return None
                end_loc = delta_text.rindex('"}')
                diff = delta_text[:end_loc] + '"}'
                logger.debug(
                    "Finishing tool and found diff that had not "
                    "been streamed yet: %s",
                    diff,
                )
                self.streamed_args_for_tool[self.current_tool_id] += diff
                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            function=DeltaFunctionCall(arguments=diff).model_dump(
                                exclude_none=True
                            ),
                        )
                    ]
                )

        # case -- otherwise we're just generating text
        else:
            text = delta_text.replace(self.tool_call_start_token, "")
            text = text.replace(self.tool_call_end_token, "")
            delta = DeltaMessage(tool_calls=[], content=text)
            return delta

        try:
            current_tool_call = (
                partial_json_parser.loads(tool_call_portion or "{}", flags)
                if tool_call_portion
                else None
            )
            logger.debug("Parsed tool call %s", current_tool_call)
        except partial_json_parser.core.exceptions.MalformedJSON:
            logger.debug("not enough tokens to parse into JSON yet")
            return None
        except json.decoder.JSONDecodeError:
            logger.debug("unable to parse JSON")
            return None

        # case - we haven't sent the tool name yet. If it's available, send
        # it. otherwise, wait until it's available.
        if not self.current_tool_name_sent:
            if current_tool_call is None:
                return None
            function_name: str | None = current_tool_call.get("name")
            if function_name:
                self.current_tool_name_sent = True
                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            type="function",
                            id=make_tool_call_id(),
                            function=DeltaFunctionCall(
                                name=function_name
                            ).model_dump(exclude_none=True),
                        )
                    ]
                )
            else:
                return None
        # case -- otherwise, send the tool call delta

        # if the tool call portion is None, send the delta as text
        if tool_call_portion is None:
            # if there's text but not tool calls, send that -
            # otherwise None to skip chunk
            delta = (
                DeltaMessage(content=delta_text)
                if text_portion is not None
                else None
            )
            return delta

        # now, the nitty-gritty of tool calls
        # now we have the portion to parse as tool call.

        logger.debug(
            "Trying to parse current tool call with ID %s", self.current_tool_id
        )

        # if we're starting a new tool call, push an empty object in as
        # a placeholder for the arguments
        if len(self.prev_tool_call_arr) <= self.current_tool_id:
            self.prev_tool_call_arr.append({})

        # main logic for tool parsing here - compare prev. partially-parsed
        # JSON to the current partially-parsed JSON
        prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
            "arguments"
        )
        cur_arguments = current_tool_call.get("arguments")

        logger.debug("diffing old arguments: %s", prev_arguments)
        logger.debug("against new ones: %s", cur_arguments)

        # case -- no arguments have been created yet. skip sending a delta.
        if not cur_arguments and not prev_arguments:
            logger.debug("Skipping text %s - no arguments", delta_text)
            delta = None

        # case -- prev arguments are defined, but non are now.
        # probably impossible, but not a fatal error - just keep going
        elif not cur_arguments and prev_arguments:
            logger.error(
                "should be impossible to have arguments reset "
                "mid-call. skipping streaming anything."
            )
            delta = None

        # case -- we now have the first info about arguments available from
        # autocompleting the JSON
        elif cur_arguments and not prev_arguments:
            # extract the content after {"name": ..., "arguments":
            # directly from tool_call_portion as cur_arguments_json,
            # since cur_arguments may differ from the original text
            # due to partial JSON parsing
            # for example, tool_call_portion =
            # {"name": "search", "arguments": {"search_request": {"
            # but cur_arguments =
            # {"search_request": {}}
            function_name = current_tool_call.get("name")
            match = re.search(
                r'\{"name":\s*"'
                + re.escape(function_name)
                + r'"\s*,\s*"arguments":\s*(.*)',
                tool_call_portion.strip(),
                re.DOTALL,
            )
            if match:
                cur_arguments_json = match.group(1)
            else:
                cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)

            logger.debug("finding %s in %s", delta_text, cur_arguments_json)

            # get the location where previous args differ from current.
            if delta_text not in cur_arguments_json:
                return None
            args_delta_start_loc = cur_arguments_json.rindex(delta_text) + len(
                delta_text
            )

            # use that to find the actual delta
            arguments_delta = cur_arguments_json[:args_delta_start_loc]
            logger.debug("First tokens in arguments received: %s", arguments_delta)

            delta = DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        function=DeltaFunctionCall(
                            arguments=arguments_delta
                        ).model_dump(exclude_none=True),
                    )
                ]
            )
            self.streamed_args_for_tool[self.current_tool_id] += arguments_delta

        # last case -- we have an update to existing arguments.
        elif cur_arguments and prev_arguments:
            # judge whether the tool_call_portion is a complete JSON
            try:
                json.loads(tool_call_portion)
                is_complete_json = True
            except Exception:
                is_complete_json = False

            # if the delta_text ends with a '}' and tool_call_portion is a
            # complete JSON, then the last '}' does not belong to the
            # arguments, so we should trim it off
            if (
                isinstance(delta_text, str)
                and len(delta_text.rstrip()) >= 1
                and delta_text.rstrip()[-1] == "}"
                and is_complete_json
            ):
                delta_text = delta_text.rstrip()[:-1]

            logger.debug("got diff %s", delta_text)

            delta = DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        function=DeltaFunctionCall(arguments=delta_text).model_dump(
                            exclude_none=True
                        ),
                    )
                ]
            )
            self.streamed_args_for_tool[self.current_tool_id] += delta_text

        # handle saving the state for the current tool into
        # the "prev" list for use in diffing for the next iteration
        if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
            self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
        else:
            self.prev_tool_call_arr.append(current_tool_call)

        return delta

    except Exception:
        logger.exception("Error trying to handle streaming tool call.")
        return None  # do not stream a delta. skip this token ID.
|
||||
420
vllm/tool_parsers/hunyuan_a13b_tool_parser.py
Normal file
420
vllm/tool_parsers/hunyuan_a13b_tool_parser.py
Normal file
@@ -0,0 +1,420 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: E501, SIM102
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.tool_parsers.utils import consume_space
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class HunyuanA13BToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: TokenizerLike):
    """Initialize the Hunyuan A13B tool parser.

    Precompiles the regexes used to locate the ``<tool_calls>`` block and
    to pick tool names / argument objects out of partially generated
    JSON, and sets up the mutable streaming state.

    Args:
        tokenizer: the model's tokenizer; forwarded to the base
            ``ToolParser`` constructor.
    """
    super().__init__(tokenizer)

    # Initialize state for streaming mode
    self.prev_tool_calls: list[dict] = []
    self.current_tool_id = -1
    self.current_tool_name_sent = False
    self.streamed_args: list[str] = []  # Track arguments sent for each tool

    # For backward compatibility with tests
    self.current_tools_sent: list[bool] = []

    # For backward compatibility with serving code
    self.prev_tool_call_arr = []

    # Regex patterns for preprocessing
    self.answer_tool_calls_pattern = re.compile(
        r"<tool_calls>([\s\S]*?)</tool_calls>", re.DOTALL
    )

    # Captures the value of a "name" field in (possibly partial) JSON.
    self.tool_name_reg = re.compile(r'"name"\s*:\s*"([^"]+)"')

    # Matches a call whose "arguments" object is empty: {}
    self.tool_empty_arg_reg = re.compile(
        r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}'
    )

    # TODO: not support nested json object in fc arguments.
    self.tool_non_empty_arg_reg = re.compile(
        r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
    )

    self.bot_string = "<tool_calls>"

    # Define streaming state type to be initialized later
    self.streaming_state: dict[str, Any] = {
        "current_tool_index": -1,
        "tool_ids": [],
        "sent_tools": [],
    }
|
||||
|
||||
def preprocess_model_output(
|
||||
self, model_output: str
|
||||
) -> tuple[str | None, str | None]:
|
||||
# find the location tool call
|
||||
for match in self.answer_tool_calls_pattern.finditer(model_output):
|
||||
start, end = match.span()
|
||||
# check tool_calls whether in side of <think>
|
||||
think_regions = [
|
||||
(m.start(), m.end())
|
||||
for m in re.finditer(
|
||||
r"<think>(.*?)</think>", model_output, flags=re.DOTALL
|
||||
)
|
||||
]
|
||||
in_think = any(
|
||||
start > t_start and end < t_end for t_start, t_end in think_regions
|
||||
)
|
||||
if not in_think:
|
||||
content = model_output[:start]
|
||||
tool_calls_content = match.group(1).strip()
|
||||
try:
|
||||
json.loads(tool_calls_content)
|
||||
return content, tool_calls_content
|
||||
except Exception:
|
||||
continue
|
||||
return model_output, None
|
||||
|
||||
def extract_tool_calls(
    self, model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation:
    """Extract tool calls from a complete model output.

    Locates the ``<tool_calls>`` JSON array via
    :meth:`preprocess_model_output`, builds a ``ToolCall`` for every
    well-formed ``{"name": ..., "arguments": ...}`` entry, and returns
    the remaining text as content. Any failure falls back to returning
    the raw output as plain content.

    Args:
        model_output: the full text generated by the model.
        request: the originating request (unused here; part of the
            ``ToolParser`` interface).
    """
    try:
        # Preprocess the model output
        content, potential_tool_calls = self.preprocess_model_output(model_output)

        if not potential_tool_calls:
            # some text should be filtered out for no function call
            # this text is in a13b's chat template.
            if content:
                content = content.replace("助手:", "", 1)
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=content
            )

        # Parse the potential tool calls as JSON
        tool_calls_data = json.loads(potential_tool_calls)

        # Ensure it's an array
        if not isinstance(tool_calls_data, list):
            logger.debug("Tool calls data is not an array")
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=content or model_output,
            )

        tool_calls: list[ToolCall] = []

        # (Unused enumerate index removed.)
        for call in tool_calls_data:
            # Silently skip malformed entries; only dicts with both keys
            # are valid calls.
            if (
                not isinstance(call, dict)
                or "name" not in call
                or "arguments" not in call
            ):
                continue

            tool_call = ToolCall(
                id=f"call_{random_uuid()}",
                type="function",
                function=FunctionCall(
                    name=call["name"],
                    arguments=(
                        # arguments must be a JSON string; serialize dicts,
                        # pass through anything already stringified.
                        json.dumps(call["arguments"])
                        if isinstance(call["arguments"], dict)
                        else call["arguments"]
                    ),
                ),
            )
            tool_calls.append(tool_call)

        if not content or len(content.strip()) == 0:
            # clear the whitespace content.
            content = None

        return ExtractedToolCallInformation(
            tools_called=len(tool_calls) > 0,
            tool_calls=tool_calls,
            content=content,
        )

    except Exception:
        # Previously this swallowed the error silently; log it for
        # consistency with the other parsers before falling back.
        logger.exception("Error in extracting tool call from response.")
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )
|
||||
|
||||
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None:
    """
    Extract tool calls for streaming mode.

    Orchestrates the helper methods: finds the start of the JSON array
    (after the optional ``<tool_calls>`` marker), refreshes
    ``prev_tool_call_arr`` when the array is parseable, then tries in
    order: test-compatibility path, streaming a tool name, streaming
    tool arguments. Returns ``None`` when there is nothing to emit.
    """

    # Skip leading whitespace and the optional <tool_calls> marker.
    start_idx = consume_space(0, current_text)
    if current_text[start_idx:].startswith(self.bot_string):
        start_idx = consume_space(start_idx + len(self.bot_string), current_text)
    if (
        not current_text
        or start_idx >= len(current_text)
        or current_text[start_idx] != "["
    ):
        # Not a tool-call array: pass the delta through as content.
        return DeltaMessage(content=delta_text)

    # Best-effort refresh of prev_tool_call_arr from the full array.
    self._try_parse_json_tools(current_text[start_idx:])

    test_delta = self._handle_test_compatibility(current_text)
    if test_delta:
        return test_delta

    # One match per tool whose "name" field is visible so far.
    name_matches = list(self.tool_name_reg.finditer(current_text))
    tool_count = len(name_matches)
    if tool_count == 0:
        return None
    self._ensure_state_arrays(tool_count)
    current_idx = self.streaming_state["current_tool_index"]

    name_delta = self._handle_tool_name_streaming(
        current_idx, tool_count, name_matches
    )
    if name_delta:
        return name_delta

    args_delta = self._handle_tool_args_streaming(
        current_text, current_idx, tool_count
    )
    if args_delta:
        return args_delta

    return None
|
||||
|
||||
def _try_parse_json_tools(self, current_text: str):
|
||||
try:
|
||||
parsed_tools = json.loads(current_text)
|
||||
if isinstance(parsed_tools, list):
|
||||
self.prev_tool_call_arr = parsed_tools
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
def _handle_test_compatibility(self, current_text: str):
    """Legacy compatibility path driven by ``current_tools_sent``.

    When exactly one tool is pending (``current_tools_sent == [False]``)
    and a tool name is visible in *current_text*, emit the name delta
    with a ``chatcmpl-tool-...`` id and mark both the legacy fields and
    ``streaming_state`` as having sent the name. Returns the delta, or
    ``None`` when this path does not apply.
    """
    if len(self.current_tools_sent) > 0:
        if (
            len(self.current_tools_sent) == 1
            and self.current_tools_sent[0] is False
        ):
            name_match = self.tool_name_reg.search(current_text)
            if name_match:
                function_name = name_match.group(1)
                tool_id = f"chatcmpl-tool-{random_uuid()}"
                delta = DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=0,
                            type="function",
                            id=tool_id,
                            function=DeltaFunctionCall(
                                name=function_name
                            ).model_dump(exclude_none=True),
                        )
                    ]
                )
                # Mirror the send into both the legacy fields and the
                # newer streaming_state bookkeeping.
                self.current_tools_sent = [True]
                self.current_tool_id = 0
                self.streaming_state["current_tool_index"] = 0
                if len(self.streaming_state["sent_tools"]) == 0:
                    self.streaming_state["sent_tools"].append(
                        {
                            "sent_name": True,
                            "sent_arguments_prefix": False,
                            "sent_arguments": "",
                        }
                    )
                else:
                    self.streaming_state["sent_tools"][0]["sent_name"] = True
                self.current_tool_name_sent = True
                return delta
    return None
|
||||
|
||||
def _ensure_state_arrays(self, tool_count: int):
|
||||
while len(self.streaming_state["sent_tools"]) < tool_count:
|
||||
self.streaming_state["sent_tools"].append(
|
||||
{
|
||||
"sent_name": False,
|
||||
"sent_arguments_prefix": False,
|
||||
"sent_arguments": "",
|
||||
}
|
||||
)
|
||||
while len(self.streaming_state["tool_ids"]) < tool_count:
|
||||
self.streaming_state["tool_ids"].append(None)
|
||||
|
||||
    def _handle_tool_name_streaming(
        self, current_idx: int, tool_count: int, name_matches
    ):
        """Emit the next unsent tool name as a streaming delta.

        Advances ``current_tool_index`` to the next tool whose name has not
        been sent, assigns it a call id, and returns a ``DeltaMessage``
        carrying the function name. Returns ``None`` when the current tool's
        name is already announced and no further tool is ready.
        """
        # Advance only when no tool is active yet (-1) or more tools exist
        # beyond the current one.
        if current_idx == -1 or current_idx < tool_count - 1:
            next_idx = current_idx + 1
            if (
                next_idx < tool_count
                and not self.streaming_state["sent_tools"][next_idx]["sent_name"]
            ):
                self.streaming_state["current_tool_index"] = next_idx
                self.current_tool_id = next_idx
                current_idx = next_idx
                tool_name = name_matches[current_idx].group(1)
                # Index prefix plus a random suffix keeps call ids unique
                # across the tools of one response.
                tool_id = f"call_{current_idx}_{random_uuid()}"
                self.streaming_state["tool_ids"][current_idx] = tool_id
                delta = DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=current_idx,
                            type="function",
                            id=tool_id,
                            function=DeltaFunctionCall(name=tool_name).model_dump(
                                exclude_none=True
                            ),
                        )
                    ]
                )
                self.streaming_state["sent_tools"][current_idx]["sent_name"] = True
                self.current_tool_name_sent = True
                # Make sure the per-tool streamed-arguments list has a slot
                # for this tool before its arguments start flowing.
                while len(self.streamed_args) <= current_idx:
                    self.streamed_args.append("")
                return delta
        return None
||||
    def _handle_tool_args_streaming(
        self, current_text: str, current_idx: int, tool_count: int
    ):
        """Stream argument fragments for the tool at ``current_idx``.

        Handles two shapes: tools with an explicitly empty argument object
        (``{}`` sent in a single delta) and tools whose argument JSON grows
        incrementally, where only the not-yet-sent suffix is emitted each
        call. Returns a ``DeltaMessage`` with the new fragment, or ``None``
        when there is nothing new to send.
        """
        if current_idx >= 0 and current_idx < tool_count:
            # Case 1: the current tool's arguments are empty — send "{}"
            # exactly once, then advance to the next tool.
            empty_args_match = self.tool_empty_arg_reg.search(current_text)
            if empty_args_match and empty_args_match.start() > 0:
                for i in range(tool_count):
                    if i == current_idx:
                        if not self.streaming_state["sent_tools"][current_idx][
                            "sent_arguments_prefix"
                        ]:
                            self.streaming_state["sent_tools"][current_idx][
                                "sent_arguments_prefix"
                            ] = True
                            self.streaming_state["sent_tools"][current_idx][
                                "sent_arguments"
                            ] = "{}"
                            while len(self.streamed_args) <= current_idx:
                                self.streamed_args.append("")
                            self.streamed_args[current_idx] += "{}"
                            delta = DeltaMessage(
                                tool_calls=[
                                    DeltaToolCall(
                                        index=current_idx,
                                        function=DeltaFunctionCall(
                                            arguments="{}"
                                        ).model_dump(exclude_none=True),
                                    )
                                ]
                            )
                            # Move the cursor on if another tool follows.
                            if current_idx < tool_count - 1:
                                self.streaming_state["current_tool_index"] += 1
                                self.current_tool_id = self.streaming_state[
                                    "current_tool_index"
                                ]
                            return delta

            # Case 2: non-empty arguments that grow as tokens arrive.
            args_matches = list(self.tool_non_empty_arg_reg.finditer(current_text))
            if current_idx < len(args_matches):
                args_text = args_matches[current_idx].group(1)
                is_last_tool = current_idx == tool_count - 1
                if not is_last_tool:
                    # For non-final tools, clip the captured arguments at the
                    # "},{" boundary where the next tool object begins.
                    next_tool_pos = current_text.find(
                        "},{", args_matches[current_idx].start()
                    )
                    if next_tool_pos != -1:
                        args_end_pos = next_tool_pos + 1
                        args_text = (
                            current_text[
                                args_matches[current_idx].start() : args_end_pos
                            ]
                            .split('"arguments":')[1]
                            .strip()
                        )
                sent_args = self.streaming_state["sent_tools"][current_idx][
                    "sent_arguments"
                ]
                # Send the opening "{" once before any argument content.
                if not self.streaming_state["sent_tools"][current_idx][
                    "sent_arguments_prefix"
                ] and args_text.startswith("{"):
                    self.streaming_state["sent_tools"][current_idx][
                        "sent_arguments_prefix"
                    ] = True
                    self.streaming_state["sent_tools"][current_idx][
                        "sent_arguments"
                    ] = "{"
                    while len(self.streamed_args) <= current_idx:
                        self.streamed_args.append("")
                    self.streamed_args[current_idx] += "{"
                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=current_idx,
                                function=DeltaFunctionCall(arguments="{").model_dump(
                                    exclude_none=True
                                ),
                            )
                        ]
                    )
                    return delta

                # Emit only the suffix that has not been streamed yet.
                if args_text.startswith(sent_args):
                    args_diff = args_text[len(sent_args) :]
                    if args_diff:
                        self.streaming_state["sent_tools"][current_idx][
                            "sent_arguments"
                        ] = args_text
                        while len(self.streamed_args) <= current_idx:
                            self.streamed_args.append("")
                        self.streamed_args[current_idx] += args_diff
                        delta = DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    index=current_idx,
                                    function=DeltaFunctionCall(
                                        arguments=args_diff
                                    ).model_dump(exclude_none=True),
                                )
                            ]
                        )
                        return delta

                # Arguments complete and fully streamed: advance the cursor
                # to the next tool if one remains.
                if args_text.endswith("}") and args_text == sent_args:
                    if current_idx < tool_count - 1:
                        self.streaming_state["current_tool_index"] += 1
                        self.current_tool_id = self.streaming_state[
                            "current_tool_index"
                        ]
        return None
227
vllm/tool_parsers/internlm2_tool_parser.py
Normal file
227
vllm/tool_parsers/internlm2_tool_parser.py
Normal file
@@ -0,0 +1,227 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.tool_parsers.utils import extract_intermediate_diff
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Internlm2ToolParser(ToolParser):
    """Tool parser for InternLM2-style models.

    InternLM2 wraps a single tool invocation (no parallel tool calls)
    between the special tokens ``<|action_start|><|plugin|>`` and
    ``<|action_end|>``, with one JSON object in between.
    """

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)
        # Offset into current_text up to which plain (non-tool) content has
        # already been streamed to the client.
        self.position = 0

    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Keep special tokens in the output when tools are in play."""
        request = super().adjust_request(request)
        if request.tools and request.tool_choice != "none":
            # do not skip special tokens because internlm use the special
            # tokens to indicate the start and end of the tool calls
            # information.
            request.skip_special_tokens = False
        return request

    def get_arguments(self, obj):
        """Return the call arguments, accepting either key spelling.

        InternLM2 may emit the arguments under ``"parameters"`` or
        ``"arguments"``; ``None`` is returned when neither key is present.
        """
        if "parameters" in obj:
            return obj.get("parameters")
        elif "arguments" in obj:
            return obj.get("arguments")
        return None

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Incrementally extract the tool call from streamed output.

        Plain content before the action marker is passed through as content
        deltas; once inside the action block, the partial JSON is parsed and
        name/argument deltas are emitted. Returns ``None`` when there is
        nothing new to send for this chunk.
        """
        if "<|action_start|>" not in current_text:
            self.position = len(current_text)
            return DeltaMessage(content=delta_text)
        # if the tool call is sent, return an empty delta message
        # to make sure the finish_reason will be sent correctly.
        if self.current_tool_id > 0:
            return DeltaMessage(content="")

        last_pos = self.position
        if "<|action_start|><|plugin|>" not in current_text[last_pos:]:
            return None

        new_delta = current_text[last_pos:]
        text, action = new_delta.split("<|action_start|><|plugin|>")

        # Flush any plain content preceding the action marker first.
        if len(text) > 0:
            self.position = self.position + len(text)
            return DeltaMessage(content=text)

        action = action.strip()
        # Drop everything from the end marker onward (was
        # `"<|action_end|>".strip()` — a no-op strip on a literal).
        action = action.split("<|action_end|>")[0]

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR

        try:
            parsable_arr = action

            # tool calls are generated in an object in internlm2
            # it's not support parallel tool calls
            try:
                tool_call_arr: dict = partial_json_parser.loads(parsable_arr, flags)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug("not enough tokens to parse into JSON yet")
                return None

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            if not self.current_tool_name_sent:
                function_name = tool_call_arr.get("name")
                if function_name:
                    self.current_tool_id = self.current_tool_id + 1
                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                type="function",
                                id=make_tool_call_id(),
                                function=DeltaFunctionCall(
                                    name=function_name
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.current_tool_name_sent = True
                    self.streamed_args_for_tool.append("")
                else:
                    delta = None
            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                prev_arguments = self.get_arguments(
                    self.prev_tool_call_arr[self.current_tool_id]
                )
                cur_arguments = self.get_arguments(tool_call_arr)

                # not arguments generated
                if not cur_arguments and not prev_arguments:
                    delta = None
                # will never happen
                elif not cur_arguments and prev_arguments:
                    logger.error(
                        "INVARIANT - impossible to have arguments reset mid-arguments"
                    )
                    delta = None
                # first time to get parameters
                elif cur_arguments and not prev_arguments:
                    cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)

                    arguments_delta = cur_arguments_json[
                        : cur_arguments_json.index(delta_text) + len(delta_text)
                    ]
                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                function=DeltaFunctionCall(
                                    arguments=arguments_delta
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.streamed_args_for_tool[self.current_tool_id] += arguments_delta
                # both prev and cur parameters, send the increase parameters
                elif cur_arguments and prev_arguments:
                    cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                    prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)

                    argument_diff = extract_intermediate_diff(
                        cur_args_json, prev_args_json
                    )

                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                function=DeltaFunctionCall(
                                    arguments=argument_diff
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.streamed_args_for_tool[self.current_tool_id] += argument_diff

            # check to see if the name is defined and has been sent. if so,
            # stream the name - otherwise keep waiting
            # finish by setting old and returning None as base case
            # Normalize "parameters" -> "arguments" before caching the call.
            tool_call_arr["arguments"] = self.get_arguments(tool_call_arr)
            self.prev_tool_call_arr = [tool_call_arr]
            return delta
        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction error"
            )
            return None

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract the tool call from a complete (non-streaming) response.

        Returns the parsed tool call when the action markers are present and
        the called name matches a requested tool; otherwise returns the raw
        text as content with ``tools_called=False``.
        """
        text = model_output
        tools = request.tools
        if "<|action_start|><|plugin|>" in text:
            text, action = text.split("<|action_start|><|plugin|>")
            # No-op `.strip()` on the literal removed (behavior unchanged).
            action = action.split("<|action_end|>")[0]
            action = action[action.find("{") :]
            action_dict = json.loads(action)
            name, parameters = (
                action_dict["name"],
                json.dumps(
                    action_dict.get("parameters", action_dict.get("arguments", {})),
                    ensure_ascii=False,
                ),
            )

            if not tools or name not in [t.function.name for t in tools]:
                # BUG FIX: the fallback result was previously constructed but
                # never returned, so an unknown tool name fell through and was
                # reported as a successful tool call.
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=text
                )

            tool_calls = [
                ToolCall(function=FunctionCall(name=name, arguments=parameters))
            ]
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=text if len(text) > 0 else None,
            )

        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=text
        )
||||
323
vllm/tool_parsers/jamba_tool_parser.py
Normal file
323
vllm/tool_parsers/jamba_tool_parser.py
Normal file
@@ -0,0 +1,323 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
|
||||
import partial_json_parser
|
||||
import regex as re
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.tool_parsers import ToolParser
|
||||
from vllm.tool_parsers.utils import extract_intermediate_diff
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class JambaToolParser(ToolParser):
    """Tool parser for Jamba models.

    Jamba emits tool calls as a JSON array wrapped between the special
    ``<tool_calls>`` and ``</tool_calls>`` tokens.
    """

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)

        # A MistralTokenizer here means the wrong model/tokenizer pairing
        # was configured for a Jamba model.
        if isinstance(self.model_tokenizer, MistralTokenizer):
            raise ValueError(
                "Detected a MistralTokenizer tokenizer when using a Jamba model"
            )

        # Streaming bookkeeping state.
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[
            str
        ] = []  # map what has been streamed for each tool so far to a list

        self.tool_calls_start_token: str = "<tool_calls>"
        self.tool_calls_end_token: str = "</tool_calls>"

        # Captures everything between the start/end tokens (DOTALL so the
        # JSON array may span multiple lines).
        self.tool_calls_regex = re.compile(
            rf"{self.tool_calls_start_token}(.*?){self.tool_calls_end_token}", re.DOTALL
        )

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )
        self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token)
        self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token)
        if (
            self.tool_calls_start_token_id is None
            or self.tool_calls_end_token_id is None
        ):
            raise RuntimeError(
                "Jamba Tool parser could not locate tool calls start/end "
                "tokens in the tokenizer!"
            )

    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Keep special tokens in the output when tools are in play."""
        request = super().adjust_request(request)
        if request.tools and request.tool_choice != "none":
            # do not skip special tokens because jamba use the special
            # tokens to indicate the start and end of the tool calls
            # information.
            request.skip_special_tokens = False
        return request

    def extract_tool_calls(
        self, model_output: str, request: ChatCompletionRequest
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from a complete (non-streaming) response.

        Parses the JSON array between the tool-call tokens; on any parsing
        error the whole output is returned as plain content.
        """
        # sanity check; avoid unnecessary processing
        if self.tool_calls_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        else:
            try:
                # use a regex to find the tool call between the tags
                function_calls = self.tool_calls_regex.findall(model_output)[0]

                # load the JSON, and then use it to build the Function and
                # Tool Call
                raw_function_calls = json.loads(function_calls)
                tool_calls = [
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=function_call["name"],
                            # function call args are JSON but as a string
                            arguments=json.dumps(
                                function_call["arguments"], ensure_ascii=False
                            ),
                        ),
                    )
                    for function_call in raw_function_calls
                ]

                # Everything before the start token is regular content.
                content = model_output[: model_output.find(self.tool_calls_start_token)]
                return ExtractedToolCallInformation(
                    tools_called=True,
                    tool_calls=tool_calls,
                    content=content if (len(content) > 0 and content != " ") else None,
                )

            except Exception:
                logger.exception("Error in extracting tool call from response.")
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=model_output
                )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Incrementally extract tool calls from streamed output.

        Uses partial-JSON parsing of the text between the tool-call tokens
        and emits name/argument deltas as they become parseable. Returns
        ``None`` when nothing new can be sent for this chunk.
        """
        # if the tool call token is not in the tokens generated so far, append
        # output to contents since it's not a tool
        if self.tool_calls_start_token not in current_text:
            return DeltaMessage(content=delta_text)

        # if the tool call token ID IS in the tokens generated so far, that
        # means we're parsing as tool calls now

        # handle if we detected the start of tool calls token which means
        # the start of tool calling
        if (
            self.tool_calls_start_token_id in delta_token_ids
            and len(delta_token_ids) == 1
        ):
            # if it's the only token, return None, so we don't send a chat
            # completion and don't send a control token
            return None

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
        try:
            # Extract the tool calls between the special tool call tokens
            parsable_arr = current_text.split(self.tool_calls_start_token)[-1].split(
                self.tool_calls_end_token
            )[0]

            # tool calls are generated in an array, so do partial JSON
            # parsing on the entire array
            try:
                tool_call_arr: list[dict] = partial_json_parser.loads(
                    parsable_arr, flags
                )
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug("not enough tokens to parse into JSON yet")
                return None

            # select as the current tool call the one we're on the state at

            current_tool_call: dict = (
                tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
            )

            # case -- if no tokens have been streamed for the tool, e.g.
            # only the array brackets, stream nothing
            if len(tool_call_arr) == 0:
                return None

            # case: we are starting a new tool in the array
            # -> array has > 0 length AND length has moved past cursor
            elif (
                len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
            ):
                # if we're moving on to a new call, first make sure we
                # haven't missed anything in the previous one that was
                # auto-generated due to JSON completions, but wasn't
                # streamed to the client yet.
                if self.current_tool_id >= 0:
                    diff: str | None = current_tool_call.get("arguments")

                    if diff:
                        diff = json.dumps(diff, ensure_ascii=False).replace(
                            self.streamed_args_for_tool[self.current_tool_id], ""
                        )
                        delta = DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_id,
                                    function=DeltaFunctionCall(
                                        arguments=diff
                                    ).model_dump(exclude_none=True),
                                )
                            ]
                        )
                        self.streamed_args_for_tool[self.current_tool_id] += diff
                    else:
                        delta = None
                else:
                    delta = None
                # re-set stuff pertaining to progress in the current tool
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("starting on new tool %d", self.current_tool_id)
                return delta

            # case: update an existing tool - this is handled below

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            if not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:
                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                type="function",
                                id=make_tool_call_id(),
                                function=DeltaFunctionCall(
                                    name=function_name
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.current_tool_name_sent = True
                else:
                    delta = None

            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                    "arguments"
                )
                cur_arguments = current_tool_call.get("arguments")

                new_text = delta_text.replace("'", '"')

                if not cur_arguments and not prev_arguments:
                    delta = None
                elif not cur_arguments and prev_arguments:
                    logger.error(
                        "INVARIANT - impossible to have arguments reset mid-arguments"
                    )
                    delta = None
                elif cur_arguments and not prev_arguments:
                    cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)
                    logger.debug("finding %s in %s", new_text, cur_arguments_json)

                    arguments_delta = cur_arguments_json[
                        : cur_arguments_json.index(new_text) + len(new_text)
                    ]
                    logger.debug(
                        "First tokens in arguments received: %s", arguments_delta
                    )
                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                function=DeltaFunctionCall(
                                    arguments=arguments_delta
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.streamed_args_for_tool[self.current_tool_id] += arguments_delta

                elif cur_arguments and prev_arguments:
                    cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                    prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
                    logger.debug(
                        "Searching for diff between \n%s\n%s",
                        cur_args_json,
                        prev_args_json,
                    )

                    argument_diff = extract_intermediate_diff(
                        cur_args_json, prev_args_json
                    )
                    logger.debug("got arguments diff: %s", argument_diff)
                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                function=DeltaFunctionCall(
                                    arguments=argument_diff
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.streamed_args_for_tool[self.current_tool_id] += argument_diff
                else:
                    # try parsing it with regular JSON - if it works we're
                    # at the end, and we need to send the difference between
                    # tokens streamed so far and the valid JSON
                    delta = None

            # check to see if the name is defined and has been sent. if so,
            # stream the name - otherwise keep waiting
            # finish by setting old and returning None as base case
            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction error"
            )
            return None
||||
590
vllm/tool_parsers/kimi_k2_tool_parser.py
Normal file
590
vllm/tool_parsers/kimi_k2_tool_parser.py
Normal file
@@ -0,0 +1,590 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# code modified from deepseekv3_tool_parser.py
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class KimiK2ToolParser(ToolParser):
|
||||
    def __init__(self, tokenizer: TokenizerLike):
        """Initialize streaming state and resolve Kimi-K2 special-token ids.

        Raises:
            ValueError: if no model tokenizer was provided.
            RuntimeError: if the tool-calls section begin/end tokens are
                missing from the tokenizer vocabulary.
        """
        super().__init__(tokenizer)
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[
            str
        ] = []  # map what has been streamed for each tool so far to a list

        # Section-level state management to prevent token leakage
        self.in_tool_section: bool = False
        self.token_buffer: str = ""
        # Buffer size: empirical worst-case for longest marker (~30 chars) * 2
        # + safety margin for unicode + partial overlap. Prevents unbounded growth.
        self.buffer_max_size: int = 1024
        self.section_char_count: int = 0  # Track characters processed in tool section
        self.max_section_chars: int = 8192  # Force exit if section exceeds this
        self._buffer_overflow_logged: bool = False  # Log overflow once per session

        # Support both singular and plural variants
        self.tool_calls_start_token: str = "<|tool_calls_section_begin|>"
        self.tool_calls_end_token: str = "<|tool_calls_section_end|>"
        self.tool_calls_start_token_variants: list[str] = [
            "<|tool_calls_section_begin|>",
            "<|tool_call_section_begin|>",  # singular variant
        ]
        self.tool_calls_end_token_variants: list[str] = [
            "<|tool_calls_section_end|>",
            "<|tool_call_section_end|>",  # singular variant
        ]

        self.tool_call_start_token: str = "<|tool_call_begin|>"
        self.tool_call_end_token: str = "<|tool_call_end|>"

        # Matches one complete tool call: id, argument marker, arguments,
        # end marker (non-greedy, DOTALL so arguments may span lines).
        self.tool_call_regex = re.compile(
            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[^<]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>(?:(?!<\|tool_call_begin\|>).)*?)\s*<\|tool_call_end\|>",
            re.DOTALL,
        )

        # Streaming variants: match a partially generated tool call (id plus
        # whatever arguments have arrived), or just the id portion.
        self.stream_tool_call_portion_regex = re.compile(
            r"(?P<tool_call_id>.+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>.*)"
        )

        self.stream_tool_call_name_regex = re.compile(r"(?P<tool_call_id>.+:\d+)\s*")

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )
        self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token)
        self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token)

        # Get token IDs for all variants
        self.tool_calls_start_token_ids: list[int] = [
            tid
            for variant in self.tool_calls_start_token_variants
            if (tid := self.vocab.get(variant)) is not None
        ]
        self.tool_calls_end_token_ids: list[int] = [
            tid
            for variant in self.tool_calls_end_token_variants
            if (tid := self.vocab.get(variant)) is not None
        ]

        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)

        if (
            self.tool_calls_start_token_id is None
            or self.tool_calls_end_token_id is None
        ):
            raise RuntimeError(
                "Kimi-K2 Tool parser could not locate tool call start/end "
                "tokens in the tokenizer!"
            )
||||
def _check_and_strip_markers(self, text: str) -> tuple[str, bool, bool]:
|
||||
"""
|
||||
Check for section begin/end markers in text and strip them.
|
||||
Returns: (cleaned_text, found_section_begin, found_section_end)
|
||||
"""
|
||||
found_begin = False
|
||||
found_end = False
|
||||
cleaned = text
|
||||
|
||||
# Check for section begin markers (any variant)
|
||||
for variant in self.tool_calls_start_token_variants:
|
||||
if variant in cleaned:
|
||||
cleaned = cleaned.replace(variant, "")
|
||||
found_begin = True
|
||||
|
||||
# Check for section end markers (any variant)
|
||||
for variant in self.tool_calls_end_token_variants:
|
||||
if variant in cleaned:
|
||||
cleaned = cleaned.replace(variant, "")
|
||||
found_end = True
|
||||
|
||||
return cleaned, found_begin, found_end
|
||||
|
||||
def _reset_section_state(self) -> None:
|
||||
"""Reset state when exiting tool section."""
|
||||
self.in_tool_section = False
|
||||
self.token_buffer = ""
|
||||
self.section_char_count = 0
|
||||
|
||||
def reset_streaming_state(self) -> None:
|
||||
"""
|
||||
Reset all streaming state. Call this between requests to prevent
|
||||
state leakage when parser instance is reused.
|
||||
"""
|
||||
# Reset section state
|
||||
self._reset_section_state()
|
||||
|
||||
# Reset parent class state
|
||||
self.current_tool_name_sent = False
|
||||
self.prev_tool_call_arr = []
|
||||
self.current_tool_id = -1
|
||||
self.streamed_args_for_tool = []
|
||||
|
||||
logger.debug("Streaming state reset")
|
||||
|
||||
    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from a complete (non-streaming) response.

        Finds every ``<|tool_call_begin|>...<|tool_call_end|>`` span after
        the section-begin token and builds one ``ToolCall`` per match. On
        any parsing error the whole output is returned as plain content.
        """
        # sanity check; avoid unnecessary processing
        if self.tool_calls_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        else:
            try:
                # there are two possible captures - between tags, or between a
                # tag and end-of-string so the result of
                # findall is an array of tuples where one is a function call and
                # the other is None
                function_call_tuples = self.tool_call_regex.findall(model_output)

                logger.debug("function_call_tuples: %s", function_call_tuples)

                tool_calls = []
                for match in function_call_tuples:
                    function_id, function_args = match
                    # function_id: functions.get_weather:0 or get_weather:0
                    function_name = function_id.split(":")[0].split(".")[-1]
                    tool_calls.append(
                        ToolCall(
                            id=function_id,
                            type="function",
                            function=FunctionCall(
                                name=function_name, arguments=function_args
                            ),
                        )
                    )

                # Everything before the section-begin token is content.
                content = model_output[: model_output.find(self.tool_calls_start_token)]
                return ExtractedToolCallInformation(
                    tools_called=True,
                    tool_calls=tool_calls,
                    content=content if content else None,
                )

            except Exception:
                logger.exception("Error in extracting tool call from response.")
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=model_output
                )
|
||||
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None:
    """Incrementally turn one streamed chunk into a content/tool-call delta.

    Relies on mutable state carried across calls (``token_buffer``,
    ``in_tool_section``, ``section_char_count``, ``prev_tool_call_arr``,
    ``streamed_args_for_tool`` …) so that section markers and tool-call
    tags split across token boundaries are still detected. Returns
    ``None`` when the chunk should be swallowed entirely.
    """
    logger.debug("delta_text: %s", delta_text)
    logger.debug("delta_token_ids: %s", delta_token_ids)

    # Flag to defer section exit until after tool parsing completes
    deferred_section_exit = False

    # Add delta to buffer for split marker detection
    self.token_buffer += delta_text

    # Enforce buffer size limit to prevent memory issues
    if len(self.token_buffer) > self.buffer_max_size:
        if not self._buffer_overflow_logged:
            logger.warning(
                "Token buffer exceeded max size (%d bytes), flushing excess. "
                "This may indicate very long markers or unusual tokenization.",
                self.buffer_max_size,
            )
            self._buffer_overflow_logged = True
        # Keep only the most recent content that might contain partial markers
        self.token_buffer = self.token_buffer[-self.buffer_max_size // 2 :]

    # Check buffer for section markers (handles split tokens)
    buffered_text, found_section_begin, found_section_end = (
        self._check_and_strip_markers(self.token_buffer)
    )

    # Track section state transitions
    if found_section_begin and not self.in_tool_section:
        logger.debug("Entering tool section")
        self.in_tool_section = True
        self.token_buffer = buffered_text  # Use cleaned buffer
        self.section_char_count = 0  # Reset counter for new section
    if found_section_end and self.in_tool_section:
        logger.debug("Detected section end marker")
        # CRITICAL: Don't exit early if tool_call_end is in this chunk.
        # Tool parser must emit final arguments/close first to avoid dropping
        # the final tool update and leaking tokens into reasoning channel.
        has_tool_end = self.tool_call_end_token_id in delta_token_ids
        if has_tool_end:
            # Defer exit until after tool parsing completes
            deferred_section_exit = True
            logger.debug("Deferring section exit: tool_call_end in same chunk")
            self.token_buffer = buffered_text
        else:
            # No tool call ending, safe to exit immediately
            logger.debug("Exiting tool section")
            remaining = buffered_text
            self._reset_section_state()
            # Return remaining text as reasoning content if non-empty
            if remaining.strip():
                return DeltaMessage(content=remaining)
            # Return empty delta to maintain function contract
            # (always returns DeltaMessage)
            return DeltaMessage(content="")
    else:
        # No section end in this chunk; keep the marker-stripped buffer.
        self.token_buffer = buffered_text

    # Check if any variant of section start token is in current_token_ids
    has_section_token = any(
        tid in current_token_ids for tid in self.tool_calls_start_token_ids
    )

    # Early return: if no section token detected yet, return as reasoning content
    if not has_section_token and not self.in_tool_section:
        logger.debug("No tool call tokens found!")
        # Don't clear buffer - it needs to accumulate partial markers across deltas
        # Buffer overflow is already protected by lines 215-224
        return DeltaMessage(content=delta_text)

    # Strip section markers from delta_text for subsequent processing
    # NOTE: This preprocessing happens BEFORE the regex-based tool call
    # parsing (from PR #24847) to ensure markers are removed cleanly
    # before pattern matching. No double-stripping occurs because
    # section markers and tool call markers are distinct.
    delta_text, _, _ = self._check_and_strip_markers(delta_text)

    # Error recovery: If in tool section for too long, force exit
    if self.in_tool_section:
        self.section_char_count += len(delta_text)
        if self.section_char_count > self.max_section_chars:
            logger.warning(
                "Tool section exceeded max length (%d chars), forcing exit. "
                "This may indicate malformed model output.",
                self.max_section_chars,
            )
            self._reset_section_state()
            # Deferred exit already handled by forced exit above
            # Return remaining content as reasoning (or empty delta if no content)
            return DeltaMessage(content=delta_text if delta_text.strip() else "")

    try:
        # figure out where we are in the parsing by counting tool call
        # start & end tags
        prev_tool_start_count = previous_token_ids.count(
            self.tool_call_start_token_id
        )
        prev_tool_end_count = previous_token_ids.count(self.tool_call_end_token_id)
        cur_tool_start_count = current_token_ids.count(
            self.tool_call_start_token_id
        )
        cur_tool_end_count = current_token_ids.count(self.tool_call_end_token_id)
        tool_call_portion = None
        text_portion = None

        # case: if we're generating text, OR rounding out a tool call
        if (
            cur_tool_start_count == cur_tool_end_count
            and prev_tool_end_count == cur_tool_end_count
            and self.tool_call_end_token not in delta_text
        ):
            # CRITICAL FIX: Suppress content if in tool section but
            # no tool calls started
            if self.in_tool_section and cur_tool_start_count == 0:
                logger.debug(
                    "In tool section but no tool calls started yet. "
                    "Suppressing: %s",
                    delta_text,
                )
                # Return empty delta to maintain iterator contract
                return DeltaMessage(content="")
            logger.debug("Generating text content! skipping tool parsing.")
            return DeltaMessage(content=delta_text)

        if self.tool_call_end_token in delta_text:
            logger.debug("tool_call_end_token in delta_text")
            full_text = current_text + delta_text
            tool_call_portion = (
                full_text.split(self.tool_call_start_token)[-1]
                .split(self.tool_call_end_token)[0]
                .rstrip()
            )
            delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip()
            # NOTE(review): delta_text is truncated at the end token on the
            # line above, so the split below can never find the end token
            # again and text_portion is effectively delta_text.lstrip() —
            # presumably these two lines were meant to run in the other
            # order; confirm before changing.
            text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip()

        # case -- we're starting a new tool call
        if (
            cur_tool_start_count > cur_tool_end_count
            and cur_tool_start_count > prev_tool_start_count
        ):
            if len(delta_token_ids) > 1:
                tool_call_portion = current_text.split(self.tool_call_start_token)[
                    -1
                ]
            else:
                tool_call_portion = None
                delta = None

            text_portion = None

            # set cursors and state appropriately
            self.current_tool_id += 1
            self.current_tool_name_sent = False
            self.streamed_args_for_tool.append("")
            logger.debug("Starting on a new tool %s", self.current_tool_id)

        # case -- we're updating an existing tool call
        elif (
            cur_tool_start_count > cur_tool_end_count
            and cur_tool_start_count == prev_tool_start_count
        ):
            # get the portion of the text that's the tool call
            tool_call_portion = current_text.split(self.tool_call_start_token)[-1]
            text_portion = None

        # case -- the current tool call is being closed.
        elif (
            cur_tool_start_count == cur_tool_end_count
            and cur_tool_end_count >= prev_tool_end_count
        ):
            if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0:
                logger.debug("attempting to close tool call, but no tool call")
                # Handle deferred section exit before returning
                if deferred_section_exit and self.in_tool_section:
                    self._reset_section_state()
                return None
            diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
            if diff:
                diff = (
                    diff.encode("utf-8").decode("unicode_escape")
                    # NOTE(review): `diff is str` compares the value against
                    # the *type* object and is always False, so the unescape
                    # branch never runs — presumably `isinstance(diff, str)`
                    # was intended; confirm before changing.
                    if diff is str
                    else diff
                )
                if '"}' not in delta_text:
                    # Handle deferred section exit before returning
                    if deferred_section_exit and self.in_tool_section:
                        self._reset_section_state()
                    return None
                end_loc = delta_text.rindex('"}')
                diff = delta_text[:end_loc] + '"}'
                logger.debug(
                    "Finishing tool and found diff that had not "
                    "been streamed yet: %s",
                    diff,
                )
                self.streamed_args_for_tool[self.current_tool_id] += diff
                # Handle deferred section exit before returning
                if deferred_section_exit and self.in_tool_section:
                    logger.debug("Completing deferred section exit")
                    self._reset_section_state()
                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            function=DeltaFunctionCall(arguments=diff).model_dump(
                                exclude_none=True
                            ),
                        )
                    ]
                )

        # case -- otherwise we're just generating text
        else:
            # Check if we're in tool section - if so, suppress
            if self.in_tool_section:
                logger.debug("In tool section, suppressing text generation")
                # Handle deferred section exit before returning
                if deferred_section_exit:
                    self._reset_section_state()
                return DeltaMessage(content="")
            text = delta_text.replace(self.tool_call_start_token, "")
            text = text.replace(self.tool_call_end_token, "")
            delta = DeltaMessage(tool_calls=[], content=text)
            # Handle deferred section exit before returning
            if deferred_section_exit and self.in_tool_section:
                self._reset_section_state()
            return delta

        current_tool_call = dict()
        if tool_call_portion:
            current_tool_call_matches = self.stream_tool_call_portion_regex.match(
                tool_call_portion
            )
            if current_tool_call_matches:
                tool_id, tool_args = current_tool_call_matches.groups()
                tool_name = tool_id.split(":")[0].split(".")[-1]
                current_tool_call["id"] = tool_id
                current_tool_call["name"] = tool_name
                current_tool_call["arguments"] = tool_args
            else:
                current_tool_call_name_matches = (
                    self.stream_tool_call_name_regex.match(tool_call_portion)
                )
                if current_tool_call_name_matches:
                    (tool_id_str,) = current_tool_call_name_matches.groups()
                    tool_name = tool_id_str.split(":")[0].split(".")[-1]
                    current_tool_call["id"] = tool_id_str
                    current_tool_call["name"] = tool_name
                    current_tool_call["arguments"] = ""
                else:
                    logger.debug("Not enough token")
                    return None

        # case - we haven't sent the tool name yet. If it's available, send
        # it. otherwise, wait until it's available.
        if not self.current_tool_name_sent:
            if current_tool_call is None:
                return None
            function_name: str | None = current_tool_call.get("name")
            tool_id = current_tool_call.get("id")
            if function_name:
                self.current_tool_name_sent = True
                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            type="function",
                            id=tool_id,
                            function=DeltaFunctionCall(
                                name=function_name
                            ).model_dump(exclude_none=True),
                        )
                    ]
                )
            else:
                return None

        # case -- otherwise, send the tool call delta

        # if the tool call portion is None, send the delta as text
        if tool_call_portion is None:
            # if there's text but not tool calls, send that -
            # otherwise None to skip chunk
            delta = (
                DeltaMessage(content=delta_text)
                if text_portion is not None
                else None
            )
            return delta

        # now, the nitty-gritty of tool calls
        # now we have the portion to parse as tool call.

        logger.debug(
            "Trying to parse current tool call with ID %s", self.current_tool_id
        )

        # if we're starting a new tool call, push an empty object in as
        # a placeholder for the arguments
        if len(self.prev_tool_call_arr) <= self.current_tool_id:
            self.prev_tool_call_arr.append({})

        # main logic for tool parsing here - compare prev. partially-parsed
        # JSON to the current partially-parsed JSON
        prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
            "arguments"
        )
        cur_arguments = current_tool_call.get("arguments")

        logger.debug("diffing old arguments: %s", prev_arguments)
        logger.debug("against new ones: %s", cur_arguments)

        # case -- no arguments have been created yet. skip sending a delta.
        if not cur_arguments and not prev_arguments:
            logger.debug("Skipping text %s - no arguments", delta_text)
            delta = None

        # case -- prev arguments are defined, but non are now.
        # probably impossible, but not a fatal error - just keep going
        elif not cur_arguments and prev_arguments:
            logger.error(
                "should be impossible to have arguments reset "
                "mid-call. skipping streaming anything."
            )
            delta = None

        # case -- we now have the first info about arguments available from
        # autocompleting the JSON
        elif cur_arguments and not prev_arguments:
            delta = DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        function=DeltaFunctionCall(
                            arguments=cur_arguments
                        ).model_dump(exclude_none=True),
                    )
                ]
            )
            self.streamed_args_for_tool[self.current_tool_id] = cur_arguments

        # last case -- we have an update to existing arguments.
        elif cur_arguments and prev_arguments:
            # Only stream the tail when the new arguments strictly extend
            # the previously seen prefix; anything else is skipped.
            if (
                isinstance(delta_text, str)
                and cur_arguments != prev_arguments
                and len(cur_arguments) > len(prev_arguments)
                and cur_arguments.startswith(prev_arguments)
            ):
                delta_arguments = cur_arguments[len(prev_arguments) :]
                logger.debug("got diff %s", delta_text)

                delta = DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            function=DeltaFunctionCall(
                                arguments=delta_arguments
                            ).model_dump(exclude_none=True),
                        )
                    ]
                )
                self.streamed_args_for_tool[self.current_tool_id] = cur_arguments
            else:
                delta = None

        # handle saving the state for the current tool into
        # the "prev" list for use in diffing for the next iteration
        if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
            self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
        else:
            self.prev_tool_call_arr.append(current_tool_call)

        # Handle deferred section exit after tool parsing completes
        if deferred_section_exit and self.in_tool_section:
            logger.debug("Completing deferred section exit")
            self._reset_section_state()

        return delta

    except Exception:
        logger.exception("Error trying to handle streaming tool call.")
        return None  # do not stream a delta. skip this token ID.
|
||||
341
vllm/tool_parsers/llama4_pythonic_tool_parser.py
Normal file
341
vllm/tool_parsers/llama4_pythonic_tool_parser.py
Normal file
@@ -0,0 +1,341 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _UnexpectedAstError(Exception):
    """Raised when parsed model output is not a pythonic tool-call list."""
|
||||
|
||||
|
||||
class Llama4PythonicToolParser(ToolParser):
    """
    Toolcall parser for Llama4 that produce tool calls in a pythonic style
    Use --enable-auto-tool-choice --tool-call-parser llama4_pythonic
    """

    # TODO(mdepinet): Possible future improvements:
    # 1. Support text + tools separated by either <|python_tag|> or \n\n
    # 2. Support tools outside of a list (or separated by a semicolon).
    #    This depends on item 1 for consistent streaming.
    # Neither of these are necessary for e.g. ToolACE, but both would help make
    # Llama3.2 models more reliable.

    # Matches a bracketed list of one or more name(kwarg=value, ...) calls,
    # e.g. "[get_weather(city='SF'), get_time(tz='UTC')]".
    TOOL_CALL_REGEX = re.compile(
        r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
        re.DOTALL,
    )

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)

    # Rename for readability. This is NOT a tool id.
    @property
    def current_tool_index(self) -> int:
        return self.current_tool_id

    @current_tool_index.setter
    def current_tool_index(self, value: int) -> None:
        self.current_tool_id = value

    def extract_tool_calls(
        self, model_output: str, request: ChatCompletionRequest
    ) -> ExtractedToolCallInformation:
        """
        Extract the tool calls from a complete model response.
        """

        # remove <|python_start|> and <|python_end|>
        # as Llama 4 model sometime will output those tokens
        if model_output.startswith("<|python_start|>"):
            model_output = model_output[len("<|python_start|>") :]
            model_output = model_output.replace("<|python_end|>", "")

        is_tool_call_pattern = False
        try:
            # The `regex` package supports a timeout to guard against
            # catastrophic backtracking on adversarial output.
            is_tool_call_pattern = (
                self.TOOL_CALL_REGEX.match(
                    model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
                )
                is not None
            )
        except TimeoutError:
            logger.warning("Regex timeout occurred when matching tool call pattern.")
            logger.debug(
                "Regex timeout occurred when matching user input: %s", model_output
            )

        if not is_tool_call_pattern:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        try:
            # Parse the pythonic call list with the real Python parser and
            # convert each Call node into a ToolCall.
            module = ast.parse(model_output)
            parsed = getattr(module.body[0], "value", None)
            if isinstance(parsed, ast.List) and all(
                isinstance(e, ast.Call) for e in parsed.elts
            ):
                return ExtractedToolCallInformation(
                    tools_called=True,
                    tool_calls=[
                        _handle_single_tool(e)  # type: ignore
                        for e in parsed.elts
                    ],
                    content=None,
                )
            else:
                raise _UnexpectedAstError(
                    "Tool output must be a list of function calls"
                )
        except Exception:
            logger.exception("Error in extracting tool call from response.")
            # Treat as regular text
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Re-parse the autocompleted prefix each chunk and emit the diff
        between what has already been streamed and what is now parseable."""
        if not current_text.startswith("[") and not current_text.startswith(
            "<|python_start|>"
        ):
            return DeltaMessage(content=delta_text)

        try:
            # remove <|python_start|> and <|python_end|>
            if current_text.startswith("<|python_start|>"):
                current_text = current_text[len("<|python_start|>") :]
            if current_text.endswith("<|python_end|>"):
                current_text = current_text[: current_text.rfind("<|python_end|>")]
            valid_and_added_text = _make_valid_python(current_text)
            if valid_and_added_text is None:
                # Text ends mid-name/mid-value; wait for more tokens.
                return None
            valid_text, added_text = valid_and_added_text

            module = ast.parse(valid_text)
            parsed = getattr(module.body[0], "value", None)
            if not isinstance(parsed, ast.List) or not all(
                isinstance(e, ast.Call) for e in parsed.elts
            ):
                raise _UnexpectedAstError(
                    "Tool output must be a list of function calls"
                )
            tool_calls = [
                _handle_single_tool(e)  # type: ignore
                for e in parsed.elts
            ]

            tool_deltas = []
            for index, new_call in enumerate(tool_calls):
                if index < self.current_tool_index:
                    # Already fully streamed in an earlier chunk.
                    continue

                self.current_tool_index = index
                if len(self.streamed_args_for_tool) == index:
                    self.streamed_args_for_tool.append("")

                # A call is complete when it is not the last one, or when
                # its ")]" was produced by the model rather than added by
                # _make_valid_python.
                new_call_complete = (
                    index < len(tool_calls) - 1 or ")]" not in added_text
                )
                if new_call_complete:
                    self.current_tool_index += 1

                withheld_suffix = added_text[:-2] if not new_call_complete else ""
                if not new_call_complete and added_text[-2] == ")":
                    # Function call is incomplete. Withhold the closing bracket.
                    withheld_suffix = withheld_suffix + "}"
                # Strings get single quotes in the model-produced string.
                # JSON requires double quotes.
                withheld_suffix = withheld_suffix.replace("'", '"')
                delta = _compute_tool_delta(
                    self.streamed_args_for_tool[index], new_call, index, withheld_suffix
                )

                if delta is not None:
                    tool_deltas.append(delta)
                    if (
                        delta.function is not None
                        and delta.function.arguments is not None
                    ):
                        self.streamed_args_for_tool[index] += delta.function.arguments

            # HACK: serving_chat.py inspects the internal state of tool parsers
            # when determining its final streaming delta, automatically
            # adding autocompleted JSON.
            # These two lines avoid that nonsense while ensuring finish_reason
            # is set to tool_calls when at least one tool is called.
            if tool_deltas and not self.prev_tool_call_arr:
                self.prev_tool_call_arr = [{"arguments": {}}]

            if tool_deltas:
                return DeltaMessage(tool_calls=tool_deltas)
            elif not added_text and self.current_tool_id > 0:
                # Return an empty DeltaMessage once the tool calls are all done
                # so that finish_reason gets set.
                return DeltaMessage(content="")
            else:
                return None
        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction error"
            )
            return None
|
||||
|
||||
|
||||
def _get_parameter_value(val: ast.expr) -> Any:
    """Convert a literal AST expression into the equivalent Python value.

    Supports constants, dicts with constant keys, and (nested) lists;
    anything else raises ``_UnexpectedAstError``.
    """
    if isinstance(val, ast.Constant):
        return val.value
    if isinstance(val, ast.Dict):
        if any(not isinstance(key, ast.Constant) for key in val.keys):
            raise _UnexpectedAstError("Dict tool call arguments must have literal keys")
        return {
            key.value: _get_parameter_value(item)  # type: ignore
            for key, item in zip(val.keys, val.values)
        }
    if isinstance(val, ast.List):
        return [_get_parameter_value(item) for item in val.elts]
    raise _UnexpectedAstError("Tool call arguments must be literals")
|
||||
|
||||
|
||||
def _handle_single_tool(call: ast.Call) -> ToolCall:
    """Turn one ``ast.Call`` node into a ``ToolCall`` with JSON-encoded args."""
    if not isinstance(call.func, ast.Name):
        raise _UnexpectedAstError("Invalid tool call name")
    arguments = {
        keyword.arg: _get_parameter_value(keyword.value)
        for keyword in call.keywords
    }
    return ToolCall(
        type="function",
        function=FunctionCall(name=call.func.id, arguments=json.dumps(arguments)),
    )
|
||||
|
||||
|
||||
def _make_valid_python(text: str) -> tuple[str, str] | None:
|
||||
bracket_stack = []
|
||||
for index, char in enumerate(text):
|
||||
if char in {"[", "(", "{"}:
|
||||
bracket_stack.append(char)
|
||||
elif char == "]":
|
||||
if not bracket_stack or bracket_stack.pop() != "[":
|
||||
raise _UnexpectedAstError("Mismatched square brackets")
|
||||
elif char == ")":
|
||||
if not bracket_stack or bracket_stack.pop() != "(":
|
||||
raise _UnexpectedAstError("Mismatched parentheses")
|
||||
elif char == "}":
|
||||
if not bracket_stack or bracket_stack.pop() != "{":
|
||||
raise _UnexpectedAstError("Mismatched curly braces")
|
||||
elif char in {"'", '"'}:
|
||||
if bracket_stack and bracket_stack[-1] == char:
|
||||
if index > 0 and text[index - 1] == "\\":
|
||||
# Treat an escaped quote as a regular character
|
||||
pass
|
||||
else:
|
||||
bracket_stack.pop()
|
||||
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
|
||||
# Double quote within a single quote string or vice versa.
|
||||
pass
|
||||
else:
|
||||
bracket_stack.append(char)
|
||||
|
||||
text = text.rstrip()
|
||||
if text.endswith("=") or text.endswith(":"):
|
||||
# Since we have no type information for this property/parameter value,
|
||||
# we can't fill in a valid value.
|
||||
return None
|
||||
if bracket_stack and bracket_stack[-1] == "{":
|
||||
trailing_dict_text = text[: text.rfind("{")]
|
||||
num_keys = trailing_dict_text.count(":")
|
||||
num_values = trailing_dict_text.count(",")
|
||||
if num_keys <= num_values:
|
||||
return None # Incomplete property name within parameter value
|
||||
if bracket_stack and bracket_stack[-1] == "(":
|
||||
trailing_params_text = text[: text.rfind("(")]
|
||||
num_full_param_names = trailing_params_text.count("=")
|
||||
num_full_param_values = trailing_params_text.count(",")
|
||||
if num_full_param_names <= num_full_param_values:
|
||||
return None # Incomplete parameter name
|
||||
if text.endswith(","):
|
||||
text = text[:-1]
|
||||
if (
|
||||
bracket_stack
|
||||
and bracket_stack[-1] == "["
|
||||
and not text.endswith("[")
|
||||
and not text.endswith(")")
|
||||
):
|
||||
return None # Incomplete function name
|
||||
|
||||
added_text = ""
|
||||
for char in reversed(bracket_stack):
|
||||
if char == "[":
|
||||
added_text += "]"
|
||||
elif char == "(":
|
||||
added_text += ")"
|
||||
elif char == "{":
|
||||
added_text += "}"
|
||||
elif char == "'":
|
||||
added_text += "'"
|
||||
elif char == '"':
|
||||
added_text += '"'
|
||||
|
||||
return text + added_text, added_text
|
||||
|
||||
|
||||
def _compute_tool_delta(
    previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str
) -> DeltaToolCall | None:
    """Diff a freshly parsed call against what has already been streamed.

    Returns the opening chunk (with id, type and name) when nothing has
    been sent yet, an arguments-only delta for the unsent tail, or
    ``None`` when there is nothing new to stream.
    """
    visible_args = new_call.function.arguments
    if withheld_suffix:
        # The suffix was autocompleted rather than model-produced: hold it back.
        assert visible_args.endswith(withheld_suffix)
        visible_args = visible_args[: -len(withheld_suffix)]

    if not previously_sent_args:
        return DeltaToolCall(
            id=new_call.id,
            type="function",
            index=index,
            function=DeltaFunctionCall(
                name=new_call.function.name,
                arguments=visible_args,
            ),
        )

    arg_diff = visible_args[len(previously_sent_args) :]
    if not arg_diff:
        return None
    return DeltaToolCall(
        id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff)
    )
|
||||
324
vllm/tool_parsers/llama_tool_parser.py
Normal file
324
vllm/tool_parsers/llama_tool_parser.py
Normal file
@@ -0,0 +1,324 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
|
||||
import partial_json_parser
|
||||
import regex as re
|
||||
from partial_json_parser.core.options import Allow
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.tool_parsers.utils import (
|
||||
find_common_prefix,
|
||||
is_complete_json,
|
||||
partial_json_loads,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Llama3JsonToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for Llama 3.x and 4 models intended for use with the
|
||||
examples/tool_chat_template_llama.jinja template.
|
||||
|
||||
Used when --enable-auto-tool-choice --tool-call-parser llama3_json or
|
||||
llama4_json are set.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: PreTrainedTokenizerBase):
    """Set up streaming-parse state and the JSON-scanning machinery."""
    super().__init__(tokenizer)

    # Streaming-mode state: calls parsed so far, which call is current,
    # whether its name has been emitted, and the argument text already
    # streamed for each call.
    self.prev_tool_call_arr: list[dict] = []
    self.current_tool_id: int = -1
    self.current_tool_name_sent: bool = False
    self.streamed_args_for_tool: list[str] = []

    # <|python_tag|> may prefix the JSON tool-call payload.
    self.bot_token = "<|python_tag|>"
    self.bot_token_id = tokenizer.encode(self.bot_token, add_special_tokens=False)[0]

    # Locate opening braces with a trivial regex and let
    # json.JSONDecoder.raw_decode do the real parsing — this handles
    # arbitrary nesting depth correctly.
    self.tool_call_start_regex = re.compile(r"\{")
    self.json_decoder = json.JSONDecoder()
|
||||
|
||||
def extract_tool_calls(
    self, model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation:
    """
    Extract the tool calls from a complete model response.
    Only extracts JSON content and ignores any surrounding plain text.
    Supports both single JSON and multiple JSONs separated by semicolons.
    """
    # Quick check before running regex
    if not (self.bot_token in model_output or "{" in model_output):
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    # Keep track of the end index of the last parsed JSON object
    # so we don't parse inner brackets
    end_index = -1
    tool_calls: list[ToolCall] = []

    try:
        # finditer locates every "{"; raw_decode then consumes one full
        # JSON object from that position. The `regex` timeout guards
        # against pathological inputs.
        for match in self.tool_call_start_regex.finditer(
            model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
        ):
            start_index = match.start()
            # Skip if this brace is inside a previously parsed JSON object
            if start_index <= end_index:
                continue

            try:
                obj, json_end_index = self.json_decoder.raw_decode(
                    model_output[start_index:]
                )
                end_index = start_index + json_end_index

                # raise KeyError if missing
                name = obj["name"]
                arguments_or_params = (
                    obj["arguments"] if "arguments" in obj else obj["parameters"]
                )

                tool_calls.append(
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=name,
                            # function call args are JSON but as a string
                            arguments=json.dumps(
                                arguments_or_params, ensure_ascii=False
                            ),
                        ),
                    )
                )
            except KeyError as e:
                # Missing required key
                missing_key = str(e).strip("'\"")
                logger.exception(
                    "Couldn't extract tool call from JSON response. "
                    "Required key '%s' not present. "
                    "Returning output in content with empty tool calls.",
                    missing_key,
                )
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=model_output
                )
            except Exception:
                # Any other error during parsing
                logger.exception(
                    "Error in extracting tool call from response. "
                    "Returning output in content with empty tool calls"
                )
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=model_output
                )
    except TimeoutError:
        logger.warning("Regex timeout occurred when matching tool call pattern.")
        logger.debug(
            "Regex timeout occurred when matching user input: %s", model_output
        )
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    # If we have valid tool calls, return them normally
    if tool_calls:
        return ExtractedToolCallInformation(
            tools_called=True, tool_calls=tool_calls, content=None
        )

    # No valid tool calls found
    return ExtractedToolCallInformation(
        tools_called=False, tool_calls=[], content=model_output
    )
|
||||
|
||||
    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Stream incremental tool-call deltas from partially generated JSON.

        Called once per generation step. Parses ``current_text`` as a
        (possibly incomplete) JSON tool-call array with ``partial_json_loads``,
        compares it against the progress recorded on ``self``
        (``current_tool_id``, ``current_tool_name_sent``,
        ``streamed_args_for_tool``, ``prev_tool_call_arr``), and emits at most
        one ``DeltaMessage`` carrying the newly generated portion: either the
        complete function name (sent all at once, mirroring OpenAI behavior)
        or an argument-string fragment.

        Returns:
            A ``DeltaMessage`` with plain content (when no tool call is in
            progress), a tool-call delta, or ``None`` when nothing new can be
            streamed yet.
        """
        # Not a tool call at all: output neither starts with the bot token
        # nor with a bare JSON object, so pass the text through as content.
        if not (
            current_text.startswith(self.bot_token) or current_text.startswith("{")
        ):
            return DeltaMessage(content=delta_text)

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
        try:
            tool_call_arr = []
            is_complete = []
            try:
                # depending on the prompt format the Llama model may or may not
                # prefix the output with the <|python_tag|> token
                start_idx = (
                    len(self.bot_token)
                    if current_text.startswith(self.bot_token)
                    else 0
                )
                # Parse each ";"-separated JSON object in turn, recording for
                # each one whether it is already complete JSON.
                while start_idx < len(current_text):
                    (obj, end_idx) = partial_json_loads(current_text[start_idx:], flags)
                    is_complete.append(
                        is_complete_json(current_text[start_idx : start_idx + end_idx])
                    )
                    # Skip past this object plus the "; " separator.
                    start_idx += end_idx + len("; ")
                    # depending on the prompt Llama can use
                    # either arguments or parameters
                    if "parameters" in obj:
                        assert "arguments" not in obj, (
                            "model generated both parameters and arguments"
                        )
                        obj["arguments"] = obj["parameters"]
                    tool_call_arr.append(obj)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug("not enough tokens to parse into JSON yet")
                return None

            # select as the current tool call the one we're on the state at
            current_tool_call: dict = (
                tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
            )

            # case -- if no tokens have been streamed for the tool, e.g.
            # only the array brackets, stream nothing
            if len(tool_call_arr) == 0:
                return None

            # case: we are starting a new tool in the array
            # -> array has > 0 length AND length has moved past cursor
            elif (
                len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
            ):
                # if we're moving on to a new call, first make sure we
                # haven't missed anything in the previous one that was
                # auto-generated due to JSON completions, but wasn't
                # streamed to the client yet.
                if self.current_tool_id >= 0:
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                        # Only stream the suffix not yet sent to the client.
                        sent = len(self.streamed_args_for_tool[self.current_tool_id])
                        argument_diff = cur_args_json[sent:]

                        logger.debug("got arguments diff: %s", argument_diff)
                        delta = DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_id,
                                    function=DeltaFunctionCall(
                                        arguments=argument_diff
                                    ).model_dump(exclude_none=True),
                                )
                            ]
                        )
                        self.streamed_args_for_tool[self.current_tool_id] += (
                            argument_diff
                        )
                    else:
                        delta = None
                else:
                    delta = None
                # re-set stuff pertaining to progress in the current tool
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("starting on new tool %d", self.current_tool_id)
                return delta

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            elif not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:
                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                type="function",
                                id=make_tool_call_id(),
                                function=DeltaFunctionCall(
                                    name=function_name
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.current_tool_name_sent = True
                else:
                    delta = None

            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                cur_arguments = current_tool_call.get("arguments")
                delta = None

                if cur_arguments:
                    sent = len(self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                    prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                        "arguments"
                    )

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        # Arguments JSON is finished: flush everything unsent.
                        argument_diff = cur_args_json[sent:]
                    elif prev_arguments:
                        # Still growing: stream only the stable common prefix
                        # beyond what was already sent, so partial-JSON
                        # auto-completion artifacts are never emitted.
                        prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
                        if cur_args_json != prev_args_json:
                            prefix = find_common_prefix(prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        delta = DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_id,
                                    function=DeltaFunctionCall(
                                        arguments=argument_diff
                                    ).model_dump(exclude_none=True),
                                )
                            ]
                        )
                        self.streamed_args_for_tool[self.current_tool_id] += (
                            argument_diff
                        )

            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception:
            # Best-effort streaming: a malformed chunk is skipped, not fatal.
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction error"
            )
            return None
|
||||
37
vllm/tool_parsers/longcat_tool_parser.py
Normal file
37
vllm/tool_parsers/longcat_tool_parser.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
|
||||
|
||||
|
||||
class LongcatFlashToolParser(Hermes2ProToolParser):
    """Hermes-style tool parser specialized for LongCat-Flash models.

    Reuses all of the Hermes 2 Pro parsing logic but swaps in LongCat's
    own sentinel tags (``<longcat_tool_call>`` / ``</longcat_tool_call>``)
    and rebuilds the derived regex and token-id state from them.
    """

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)

        # LongCat wraps tool calls in its own sentinel tags.
        self.tool_call_start_token: str = "<longcat_tool_call>"
        self.tool_call_end_token: str = "</longcat_tool_call>"

        # Matches either a closed tag pair or a still-open trailing tag.
        self.tool_call_regex = re.compile(
            r"<longcat_tool_call>(.*?)</longcat_tool_call>|<longcat_tool_call>(.*)",
            re.DOTALL,
        )

        encode = self.model_tokenizer.encode
        decode = self.model_tokenizer.decode

        self.tool_call_start_token_ids = encode(
            self.tool_call_start_token, add_special_tokens=False
        )
        self.tool_call_end_token_ids = encode(
            self.tool_call_end_token, add_special_tokens=False
        )

        # Per-token string pieces of each sentinel, used for incremental
        # matching while streaming.
        self.tool_call_start_token_array = [
            decode([tid]) for tid in self.tool_call_start_token_ids
        ]
        self.tool_call_end_token_array = [
            decode([tid]) for tid in self.tool_call_end_token_ids
        ]
|
||||
643
vllm/tool_parsers/minimax_m2_tool_parser.py
Normal file
643
vllm/tool_parsers/minimax_m2_tool_parser.py
Normal file
@@ -0,0 +1,643 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MinimaxM2ToolParser(ToolParser):
|
||||
    def __init__(self, tokenizer: TokenizerLike):
        """Initialize parser state and resolve MiniMax M2 sentinel tokens.

        Args:
            tokenizer: The model's tokenizer; required so the tool-call
                sentinel strings can be mapped to token IDs.

        Raises:
            ValueError: If no tokenizer was provided.
            RuntimeError: If the tool-call start/end sentinels are not in
                the tokenizer vocabulary.
        """
        super().__init__(tokenizer)

        # History of tool calls parsed so far ({"name": ..., "arguments": ...}).
        self.prev_tool_call_arr: list[dict] = []

        # Sentinel tokens
        self.tool_call_start_token: str = "<minimax:tool_call>"
        self.tool_call_end_token: str = "</minimax:tool_call>"
        self.invoke_start_prefix: str = "<invoke name="
        self.invoke_end_token: str = "</invoke>"
        self.parameter_prefix: str = "<parameter name="
        self.parameter_end_token: str = "</parameter>"

        # Streaming state variables
        self.current_tool_name_sent: bool = False
        # Override base class type - we use string IDs for tool calls
        self.current_tool_id: str | None = None  # type: ignore
        self.streamed_args_for_tool: list[str] = []
        self.is_tool_call_started: bool = False
        self.failed_count: int = 0

        # Initialize streaming state variables
        self.current_tool_index: int = 0
        self.invoke_index: int = 0
        self.header_sent: bool = False
        self.current_function_name: str | None = None
        self.current_param_name: str | None = None
        self.current_param_value: str = ""
        self.param_count: int = 0
        self.in_param: bool = False
        self.in_function: bool = False
        self.accumulated_text: str = ""
        self.json_started: bool = False
        self.json_closed: bool = False
        self.accumulated_params: dict = {}
        self.streaming_request: ChatCompletionRequest | None = None

        # Enhanced streaming state - reset for each new message
        self._reset_streaming_state()

        # Regex patterns for complete parsing
        self.tool_call_complete_regex = re.compile(
            r"<minimax:tool_call>(.*?)</minimax:tool_call>", re.DOTALL
        )
        self.invoke_complete_regex = re.compile(
            r"<invoke name=(.*?)</invoke>", re.DOTALL
        )
        self.parameter_complete_regex = re.compile(
            r"<parameter name=(.*?)</parameter>", re.DOTALL
        )

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        # Token IDs are used for fast detection during streaming.
        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)

        if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
            raise RuntimeError(
                "MiniMax M2 Tool parser could not locate tool call start/end "
                "tokens in the tokenizer!"
            )

        logger.debug(
            "vLLM Successfully import tool parser %s !", self.__class__.__name__
        )
|
||||
|
||||
def _generate_tool_call_id(self) -> str:
|
||||
"""Generate a unique tool call ID."""
|
||||
return f"call_{uuid.uuid4().hex[:24]}"
|
||||
|
||||
def _reset_streaming_state(self):
|
||||
"""Reset all streaming state."""
|
||||
self.current_tool_index = 0
|
||||
self.invoke_index = 0
|
||||
self.is_tool_call_started = False
|
||||
self.header_sent = False
|
||||
self.current_tool_id = None
|
||||
self.current_function_name = None
|
||||
self.current_param_name = None
|
||||
self.current_param_value = ""
|
||||
self.param_count = 0
|
||||
self.in_param = False
|
||||
self.in_function = False
|
||||
self.accumulated_text = ""
|
||||
self.json_started = False
|
||||
self.json_closed = False
|
||||
# Store accumulated parameters for type conversion
|
||||
self.accumulated_params = {}
|
||||
self.streaming_request = None
|
||||
# Clear previous tool call history to avoid state pollution
|
||||
self.prev_tool_call_arr.clear()
|
||||
|
||||
def _extract_name(self, name_str: str) -> str:
|
||||
"""Extract name from quoted string."""
|
||||
name_str = name_str.strip()
|
||||
if (
|
||||
name_str.startswith('"')
|
||||
and name_str.endswith('"')
|
||||
or name_str.startswith("'")
|
||||
and name_str.endswith("'")
|
||||
):
|
||||
return name_str[1:-1]
|
||||
return name_str
|
||||
|
||||
def _convert_param_value(self, value: str, param_type: str) -> Any:
|
||||
"""Convert parameter value to the correct type."""
|
||||
if value.lower() == "null":
|
||||
return None
|
||||
|
||||
param_type = param_type.lower()
|
||||
if param_type in ["string", "str", "text"]:
|
||||
return value
|
||||
elif param_type in ["integer", "int"]:
|
||||
try:
|
||||
return int(value)
|
||||
except (ValueError, TypeError):
|
||||
return value
|
||||
elif param_type in ["number", "float"]:
|
||||
try:
|
||||
val = float(value)
|
||||
return val if val != int(val) else int(val)
|
||||
except (ValueError, TypeError):
|
||||
return value
|
||||
elif param_type in ["boolean", "bool"]:
|
||||
return value.lower() in ["true", "1"]
|
||||
elif param_type in ["object", "array"]:
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
else:
|
||||
# Try JSON parse first, fallback to string
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
|
||||
    def _parse_single_invoke(
        self, invoke_str: str, tools: list | None
    ) -> ToolCall | None:
        """Parse a single <invoke> block.

        ``invoke_str`` is the text captured after ``<invoke name=`` up to
        ``</invoke>``: the quoted function name, a ``>``, then zero or more
        ``<parameter name=...>value</parameter>`` blocks.  Parameter values
        are converted using the JSON-schema types declared in ``tools``
        (defaulting to string).

        Args:
            invoke_str: Captured invoke body (name + parameter markup).
            tools: Request tool definitions used to look up parameter types,
                or None.

        Returns:
            A ToolCall with JSON-encoded arguments, or None when no function
            name could be found.
        """
        # Extract function name
        name_match = re.search(r"^([^>]+)", invoke_str)
        if not name_match:
            return None

        function_name = self._extract_name(name_match.group(1))

        # Get parameter configuration
        param_config = {}
        if tools:
            for tool in tools:
                if (
                    hasattr(tool, "function")
                    and tool.function.name == function_name
                    and hasattr(tool.function, "parameters")
                ):
                    params = tool.function.parameters
                    if isinstance(params, dict) and "properties" in params:
                        param_config = params["properties"]
                    break

        # Extract parameters
        param_dict = {}
        for match in self.parameter_complete_regex.findall(invoke_str):
            # Each match is '<name>>value'; split at the first '>'.
            param_match = re.search(r"^([^>]+)>(.*)", match, re.DOTALL)
            if param_match:
                param_name = self._extract_name(param_match.group(1))
                param_value = param_match.group(2).strip()
                # NOTE(review): .strip() above already removed surrounding
                # newlines, so these two checks appear to be dead code.
                if param_value.startswith("\n"):
                    param_value = param_value[1:]
                if param_value.endswith("\n"):
                    param_value = param_value[:-1]

                # Get parameter type
                param_type = "string"
                if (
                    param_name in param_config
                    and isinstance(param_config[param_name], dict)
                    and "type" in param_config[param_name]
                ):
                    param_type = param_config[param_name]["type"]

                # Convert value
                param_dict[param_name] = self._convert_param_value(
                    param_value, param_type
                )

        return ToolCall(
            type="function",
            function=FunctionCall(
                name=function_name,
                arguments=json.dumps(param_dict, ensure_ascii=False),
            ),
        )
|
||||
|
||||
    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from complete model output (non-streaming).

        Scans ``model_output`` for completed
        ``<minimax:tool_call>...</minimax:tool_call>`` blocks, parses every
        ``<invoke>`` inside them, and records the results in
        ``self.prev_tool_call_arr``.

        Args:
            model_output: The full decoded model response.
            request: The originating request; its tool definitions drive
                parameter type conversion.

        Returns:
            ExtractedToolCallInformation with the parsed calls and any text
            preceding the first tool-call block as content; on failure or
            when no calls are found, the whole output is returned as content.
        """
        # Quick check
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        try:
            tool_calls = []

            # Find all complete tool_call blocks
            for tool_call_match in self.tool_call_complete_regex.findall(model_output):
                # Find all invokes within this tool_call
                for invoke_match in self.invoke_complete_regex.findall(tool_call_match):
                    tool_call = self._parse_single_invoke(
                        invoke_match, request.tools if request else None
                    )
                    if tool_call:
                        tool_calls.append(tool_call)

            if not tool_calls:
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=model_output
                )

            # Update prev_tool_call_arr
            self.prev_tool_call_arr.clear()
            for tool_call in tool_calls:
                self.prev_tool_call_arr.append(
                    {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments,
                    }
                )

            # Extract content before first tool call
            first_tool_idx = model_output.find(self.tool_call_start_token)
            content = model_output[:first_tool_idx] if first_tool_idx > 0 else None

            return ExtractedToolCallInformation(
                tools_called=True, tool_calls=tool_calls, content=content
            )

        except Exception:
            # Best-effort: fall back to returning everything as plain content.
            logger.exception("Error extracting tool calls")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )
|
||||
|
||||
    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],  # pylint: disable=unused-argument
        current_token_ids: Sequence[int],  # pylint: disable=unused-argument
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Stream incremental tool-call deltas for MiniMax M2 XML output.

        Drives a state machine kept on ``self`` (``header_sent``,
        ``json_started``/``json_closed``, ``param_count``, ...) that converts
        each ``<invoke name=...>`` block into an OpenAI-style streamed
        function call: first the function header, then ``"{"``, then one
        JSON fragment per completed ``<parameter>``, and finally ``"}"``.

        Returns:
            A ``DeltaMessage`` when there is something streamable in this
            step, or ``None`` when more text is needed.
        """

        # Store request for type conversion
        if not previous_text or self.tool_call_start_token in delta_text:
            self._reset_streaming_state()
            self.streaming_request = request

        # If no delta text, return None unless it's an EOS token after tools
        if not delta_text:
            # Check if this is an EOS token after all tool calls are complete
            if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids:
                # Count complete tool calls
                complete_calls = len(
                    self.tool_call_complete_regex.findall(current_text)
                )

                # If we have completed tool calls and populated prev_tool_call_arr
                if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
                    # Check if all tool calls are closed
                    open_calls = current_text.count(
                        self.tool_call_start_token
                    ) - current_text.count(self.tool_call_end_token)
                    if open_calls == 0:
                        # Return empty delta for finish_reason processing
                        return DeltaMessage(content="")
                elif not self.is_tool_call_started and current_text:
                    # This is a regular content response that's now complete
                    return DeltaMessage(content="")
            return None

        # Update accumulated text
        self.accumulated_text = current_text

        # Check if we need to advance to next tool
        if self.json_closed and not self.in_function:
            # Check if this tool call has ended
            invoke_ends = current_text.count(self.invoke_end_token)
            if invoke_ends > self.current_tool_index:
                # This tool has ended, advance to next
                self.current_tool_index += 1
                self.header_sent = False
                self.param_count = 0
                self.json_started = False
                self.json_closed = False
                self.in_function = False  # Now we can safely set this to False
                self.accumulated_params = {}
                # Continue processing next tool
                return None

        # Handle normal content before tool calls
        if not self.is_tool_call_started:
            # Check if tool call is starting
            if (
                self.tool_call_start_token_id in delta_token_ids
                or self.tool_call_start_token in delta_text
            ):
                self.is_tool_call_started = True
                # Return any content before the tool call
                if self.tool_call_start_token in delta_text:
                    content_before = delta_text[
                        : delta_text.index(self.tool_call_start_token)
                    ]
                    if content_before:
                        return DeltaMessage(content=content_before)
                return None
            else:
                # Check if we're between tool calls - skip whitespace
                if (
                    current_text.rstrip().endswith(self.tool_call_end_token)
                    and delta_text.strip() == ""
                ):
                    # We just ended a tool call, skip whitespace
                    return None
                # Normal content, no tool call
                return DeltaMessage(content=delta_text)

        # Check if we're between tool calls (waiting for next one)
        invoke_starts_count = current_text.count(self.invoke_start_prefix)
        if self.current_tool_index >= invoke_starts_count:
            # We're past all tool calls, shouldn't be here
            return None

        # Find the current tool call portion
        invoke_start_positions: list[int] = []
        idx = 0
        while True:
            idx = current_text.find(self.invoke_start_prefix, idx)
            if idx == -1:
                break
            invoke_start_positions.append(idx)
            idx += len(self.invoke_start_prefix)

        if self.current_tool_index >= len(invoke_start_positions):
            # No more tool calls to process yet
            return None

        invoke_start_idx = invoke_start_positions[self.current_tool_index]
        # Find where this tool call ends (or current position if not ended yet)
        invoke_end_idx = current_text.find(self.invoke_end_token, invoke_start_idx)
        if invoke_end_idx == -1:
            tool_text = current_text[invoke_start_idx:]
        else:
            tool_text = current_text[
                invoke_start_idx : invoke_end_idx + len(self.invoke_end_token)
            ]

        # Looking for function header
        if not self.header_sent:
            if self.invoke_start_prefix in tool_text:
                func_start = tool_text.find(self.invoke_start_prefix) + len(
                    self.invoke_start_prefix
                )
                # Find the end quote for the function name
                func_end = tool_text.find(">", func_start)

                if func_end != -1:
                    # Found complete function name
                    function_name_raw = tool_text[func_start:func_end]
                    self.current_function_name = self._extract_name(function_name_raw)
                    self.current_tool_id = self._generate_tool_call_id()
                    self.header_sent = True
                    self.in_function = True

                    # Add to prev_tool_call_arr immediately when we detect a tool call
                    # Each tool call should be recorded regardless of function name
                    # Ensure we don't add the same tool call index multiple times
                    if len(self.prev_tool_call_arr) <= self.current_tool_index:
                        self.prev_tool_call_arr.append(
                            {
                                "name": self.current_function_name,
                                "arguments": "{}",  # Placeholder, will be updated later
                            }
                        )

                    # Send header with function info
                    return DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_index,
                                id=self.current_tool_id,
                                function=DeltaFunctionCall(
                                    name=self.current_function_name, arguments=""
                                ),
                                type="function",
                            )
                        ]
                    )
            return None

        # We've sent header, now handle function body
        if self.in_function:
            # Send opening brace if not sent yet
            # NOTE(review): `self.in_function` is re-checked here although the
            # enclosing `if` already guarantees it.
            if self.in_function and not self.json_started:
                self.json_started = True
                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_index,
                            function=DeltaFunctionCall(arguments="{"),
                        )
                    ]
                )

            # Make sure json_started is set if we're processing parameters
            if not self.json_started:
                self.json_started = True

            # Check for function end in accumulated text
            if not self.json_closed and self.invoke_end_token in tool_text:
                # Count total parameters in the tool text
                total_param_count = tool_text.count(self.parameter_prefix)

                # Only close JSON if all parameters have been processed
                if self.param_count >= total_param_count:
                    # Close JSON
                    self.json_closed = True

                    # Extract complete tool call
                    # Find the invoke content
                    invoke_start = tool_text.find(self.invoke_start_prefix) + len(
                        self.invoke_start_prefix
                    )
                    invoke_content_end = tool_text.find(
                        self.invoke_end_token, invoke_start
                    )
                    if invoke_content_end != -1:
                        invoke_content = tool_text[invoke_start:invoke_content_end]
                        # Parse to get the complete arguments
                        try:
                            parsed_tool = self._parse_single_invoke(
                                invoke_content,
                                self.streaming_request.tools
                                if self.streaming_request
                                else None,
                            )
                            if parsed_tool and self.current_tool_index < len(
                                self.prev_tool_call_arr
                            ):
                                # Update existing entry in prev_tool_call_arr
                                args = parsed_tool.function.arguments
                                self.prev_tool_call_arr[self.current_tool_index][
                                    "arguments"
                                ] = args
                        except Exception:
                            pass  # Ignore parsing errors during streaming

                    result = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_index,
                                function=DeltaFunctionCall(arguments="}"),
                            )
                        ]
                    )

                    # Reset state for next tool
                    self.json_closed = True
                    self.in_function = False
                    self.accumulated_params = {}

                    logger.debug("[M2_STREAMING] Tool call completed")

                    return result
                else:
                    # Don't close JSON yet, continue processing parameters
                    return None

            # Look for parameters
            # Find all parameter starts
            param_starts = []
            idx = 0
            while True:
                idx = tool_text.find(self.parameter_prefix, idx)
                if idx == -1:
                    break
                param_starts.append(idx)
                idx += len(self.parameter_prefix)

            # Check if we should start a new parameter
            # NOTE(review): the last two clauses below are the same test
            # written twice.
            if (
                not self.in_param
                and self.param_count < len(param_starts)
                and len(param_starts) > self.param_count
            ):
                # Process the next parameter
                param_idx = param_starts[self.param_count]
                param_start = param_idx + len(self.parameter_prefix)
                remaining = tool_text[param_start:]

                if ">" in remaining:
                    # We have the complete parameter name
                    name_end = remaining.find(">")
                    param_name_raw = remaining[:name_end]
                    self.current_param_name = self._extract_name(param_name_raw)

                    # Find the parameter value
                    value_start = param_start + name_end + 1
                    value_text = tool_text[value_start:]
                    if value_text.startswith("\n"):
                        value_text = value_text[1:]

                    # Find where this parameter ends
                    param_end_idx = value_text.find(self.parameter_end_token)
                    if param_end_idx == -1:
                        # No closing tag, look for next parameter or function end
                        next_param_idx = value_text.find(self.parameter_prefix)
                        func_end_idx = value_text.find(self.invoke_end_token)

                        if next_param_idx != -1 and (
                            func_end_idx == -1 or next_param_idx < func_end_idx
                        ):
                            param_end_idx = next_param_idx
                        elif func_end_idx != -1:
                            param_end_idx = func_end_idx
                        else:
                            # Neither found, check if tool call is complete
                            if self.invoke_end_token in tool_text:
                                # Tool call and parameter is complete
                                param_end_idx = len(value_text)
                            else:
                                # Still streaming, wait for more content
                                return None

                    if param_end_idx != -1:
                        # Complete parameter found
                        param_value = value_text[:param_end_idx]
                        if param_value.endswith("\n"):
                            param_value = param_value[:-1]

                        # Store raw value for later processing
                        self.accumulated_params[self.current_param_name] = param_value

                        # Get parameter configuration for type conversion
                        param_config = {}
                        if self.streaming_request and self.streaming_request.tools:
                            for tool in self.streaming_request.tools:
                                if (
                                    hasattr(tool, "function")
                                    and tool.function.name == self.current_function_name
                                    and hasattr(tool.function, "parameters")
                                ):
                                    params = tool.function.parameters
                                    if (
                                        isinstance(params, dict)
                                        and "properties" in params
                                    ):
                                        param_config = params["properties"]
                                    break

                        # Get parameter type
                        param_type = "string"
                        if (
                            self.current_param_name in param_config
                            and isinstance(param_config[self.current_param_name], dict)
                            and "type" in param_config[self.current_param_name]
                        ):
                            param_type = param_config[self.current_param_name]["type"]

                        # Convert param value to appropriate type
                        converted_value = self._convert_param_value(
                            param_value, param_type
                        )

                        # Build JSON fragment based on the converted type
                        # Use json.dumps to properly serialize the value
                        serialized_value = json.dumps(
                            converted_value, ensure_ascii=False
                        )

                        if self.param_count == 0:
                            json_fragment = (
                                f'"{self.current_param_name}": {serialized_value}'
                            )
                        else:
                            json_fragment = (
                                f', "{self.current_param_name}": {serialized_value}'
                            )

                        self.param_count += 1

                        return DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_index,
                                    function=DeltaFunctionCall(arguments=json_fragment),
                                )
                            ]
                        )

        return None
|
||||
849
vllm/tool_parsers/minimax_tool_parser.py
Normal file
849
vllm/tool_parsers/minimax_tool_parser.py
Normal file
@@ -0,0 +1,849 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.tool_parsers.utils import extract_intermediate_diff
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MinimaxToolParser(ToolParser):
|
||||
    def __init__(self, tokenizer: TokenizerLike):
        """Initialize streaming state and resolve MiniMax sentinel tokens.

        Args:
            tokenizer: The model's tokenizer, used to map the tool-call
                sentinel strings to token IDs.

        Raises:
            ValueError: If no tokenizer was provided.
        """
        super().__init__(tokenizer)

        # Initialize streaming state for tracking tool call progress
        self.streaming_state: dict[str, Any] = {
            "current_tool_index": -1,  # Index of current tool being processed
            "tool_ids": [],  # List of tool call IDs
            "sent_tools": [],  # List of tools that have been sent
        }

        # Define tool call tokens and patterns
        self.tool_call_start_token = "<tool_calls>"
        self.tool_call_end_token = "</tool_calls>"
        # Matches either a closed <tool_calls> pair or a still-open one.
        self.tool_call_regex = re.compile(
            r"<tool_calls>(.*?)</tool_calls>|<tool_calls>(.*)", re.DOTALL
        )
        # Kept as a raw pattern string (compiled on use by re.sub).
        self.thinking_tag_pattern = r"<think>(.*?)</think>"
        self.tool_name_pattern = re.compile(r'"name":\s*"([^"]+)"')
        self.tool_args_pattern = re.compile(r'"arguments":\s*')

        # Buffer for handling partial tool calls during streaming
        self.pending_buffer = ""
        self.in_thinking_tag = False

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        # Get token IDs for tool call start/end tokens
        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)

        if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
            # Unlike the M2 parser, missing token IDs are not fatal here:
            # detection falls back to plain string matching.
            logger.warning(
                "Minimax Tool parser could not locate tool call start/end "
                "tokens in the tokenizer. Falling back to string matching."
            )
|
||||
|
||||
def preprocess_model_output(self, model_output: str) -> str:
    """Strip tool-call blocks that appear inside ``<think>`` sections.

    Tool calls emitted while the model is still "thinking" are not real
    calls, so they are scrubbed before any parsing takes place.

    Args:
        model_output: Raw model output string.

    Returns:
        The output with ``<tool_calls>`` blocks removed from every
        ``<think>...</think>`` region; everything else is untouched.
    """

    def _scrub(think_match):
        # Drop any (closed) tool-call block inside this think region.
        inner = re.sub(
            r"<tool_calls>.*?</tool_calls>",
            "",
            think_match.group(1),
            flags=re.DOTALL,
        )
        return f"<think>{inner}</think>"

    return re.sub(self.thinking_tag_pattern, _scrub, model_output, flags=re.DOTALL)
|
||||
|
||||
def _clean_duplicate_braces(self, args_text: str) -> str:
    """Trim redundant trailing braces until the text parses as JSON.

    Args:
        args_text: Raw arguments text, possibly with duplicated ``}``.

    Returns:
        The first JSON-parseable candidate obtained by repeatedly dropping
        one trailing brace; if nothing parses, the fully trimmed text.
    """
    candidate = args_text.strip()
    if not candidate:
        return candidate

    # Already valid JSON: nothing to clean.
    try:
        json.loads(candidate)
    except json.JSONDecodeError:
        pass
    else:
        return candidate

    # Peel one trailing '}' at a time while at least two remain.
    while candidate.endswith("}}"):
        candidate = candidate[:-1]
        try:
            json.loads(candidate)
        except json.JSONDecodeError:
            continue
        return candidate

    return candidate
|
||||
|
||||
def _clean_delta_braces(self, delta_text: str) -> str:
    """Collapse a braces-only delta down to a single closing brace.

    Args:
        delta_text: Delta text to clean.

    Returns:
        ``"}"`` (with a trailing newline preserved) when the delta consists
        solely of multiple closing braces and whitespace; otherwise the
        delta unchanged.
    """
    if not delta_text:
        return delta_text

    stripped = delta_text.strip()
    # Anything containing a non-brace, non-whitespace character passes
    # through untouched, as does an all-whitespace delta.
    if not stripped or any(ch not in "}\n\r\t " for ch in stripped):
        return delta_text
    if stripped.count("}") <= 1:
        return delta_text

    # Keep the trailing newline if the original delta ended with one.
    return "}\n" if delta_text.endswith("\n") else "}"
|
||||
|
||||
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """
    Extract tool calls from model output for non-streaming mode.

    Args:
        model_output: Complete model output
        request: Chat completion request

    Returns:
        ExtractedToolCallInformation containing tool calls and content
    """
    # Work on a copy with thinking-section tool calls removed so they
    # cannot be mistaken for real calls.
    processed_output = self.preprocess_model_output(model_output)

    if self.tool_call_start_token not in processed_output:
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    try:
        function_call_tuples = self.tool_call_regex.findall(processed_output)

        raw_function_calls = []
        for match in function_call_tuples:
            # Group 0 is a closed block, group 1 an unterminated tail.
            tool_call_content = match[0] if match[0] else match[1]
            if tool_call_content.strip():
                # One JSON object per line inside the <tool_calls> block;
                # lines that fail to parse are skipped silently.
                lines = tool_call_content.strip().split("\n")
                for line in lines:
                    line = line.strip()
                    if line and line.startswith("{") and line.endswith("}"):
                        try:
                            parsed_call = json.loads(line)
                            raw_function_calls.append(parsed_call)
                        except json.JSONDecodeError:
                            continue

        tool_calls = []
        for function_call in raw_function_calls:
            # Only objects with both mandatory keys become ToolCalls.
            if "name" in function_call and "arguments" in function_call:
                tool_calls.append(
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=function_call["name"],
                            # Arguments are re-serialized as a JSON string.
                            arguments=json.dumps(
                                function_call["arguments"], ensure_ascii=False
                            ),
                        ),
                    )
                )

        # Reconstruct the plain-text content that preceded the tool calls.
        # The preprocessed text may differ from the original, so the last
        # non-empty preprocessed line is located back in the original
        # output to find the true cut-off point.
        processed_pos = processed_output.find(self.tool_call_start_token)
        if processed_pos != -1:
            processed_content = processed_output[:processed_pos].strip()

            if processed_content:
                lines = processed_content.split("\n")
                for line in reversed(lines):
                    line = line.strip()
                    if line:
                        pos = model_output.find(line)
                        if pos != -1:
                            content = model_output[: pos + len(line)]
                            break
                else:
                    # No line could be mapped back to the original text.
                    content = ""
            else:
                content = ""
        else:
            content = model_output

        return ExtractedToolCallInformation(
            tools_called=len(tool_calls) > 0,
            tool_calls=tool_calls,
            content=content.strip() if content.strip() else None,
        )

    except Exception:
        # Any parsing failure degrades gracefully to plain content.
        logger.exception(
            "An unexpected error occurred during tool call extraction."
        )
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )
|
||||
|
||||
def _update_thinking_state(self, text: str) -> None:
    """Track whether the stream currently sits inside a ``<think>`` block.

    Args:
        text: Full text generated so far.
    """
    opened = text.count("<think>")
    closed = text.count("</think>")
    if opened > closed:
        self.in_thinking_tag = True
    else:
        # A text ending exactly on a closing tag still counts as "inside"
        # so the closing tag itself is not scanned for tool calls.
        self.in_thinking_tag = opened == closed and text.endswith("</think>")
|
||||
|
||||
def _is_potential_tag_start(self, text: str) -> bool:
    """Check whether *text* ends with a proper prefix of a tool-call tag.

    Args:
        text: Text to check.

    Returns:
        True when the tail of *text* could still grow into a full
        ``<tool_calls>`` or ``</tool_calls>`` tag.
    """
    for tag in (self.tool_call_start_token, self.tool_call_end_token):
        # Compare every suffix of *text* (up to len(tag)-1 chars) against
        # the tag's prefixes.
        longest = min(len(text) + 1, len(tag))
        for suffix_len in range(1, longest):
            if tag.startswith(text[-suffix_len:]):
                return True
    return False
|
||||
|
||||
def _should_buffer_content(self, delta_text: str) -> bool:
    """Decide whether *delta_text* must go through the pending buffer.

    Args:
        delta_text: Delta text to check.

    Returns:
        True when the delta may contain (part of) a tool-call tag or when
        buffered content is already waiting.
    """
    # Thinking content is always passed through verbatim.
    if self.in_thinking_tag:
        return False
    if self.pending_buffer:
        return True
    if self.tool_call_start_token in delta_text:
        return True
    if self.tool_call_end_token in delta_text:
        return True
    # A leading '<' might be the start of a tag split across deltas.
    return delta_text.startswith("<")
|
||||
|
||||
def _split_content_for_buffering(self, delta_text: str) -> tuple[str, str]:
    """
    Split delta text into safe content and potential tag content.

    Args:
        delta_text: Delta text to split

    Returns:
        Tuple of (safe_content, potential_tag_content)
    """
    # Inside <think> everything is safe; tags have no meaning there.
    if self.in_thinking_tag:
        return delta_text, ""

    for tag in [self.tool_call_start_token, self.tool_call_end_token]:
        for i in range(1, len(tag)):
            tag_prefix = tag[:i]
            pos = delta_text.rfind(tag_prefix)
            # Buffer only when the ENTIRE remainder of the delta is a
            # prefix of the tag, i.e. the tag may still complete in a
            # later delta.
            if pos != -1 and tag.startswith(delta_text[pos:]):
                return delta_text[:pos], delta_text[pos:]
    return delta_text, ""
|
||||
|
||||
def _process_buffer(self, new_content: str) -> str:
    """
    Process buffered content and return output content.

    Tool-call tags found in the buffer are consumed (never emitted);
    everything around them becomes plain output. A trailing fragment that
    might still grow into a tag stays buffered.

    Args:
        new_content: New content to add to buffer

    Returns:
        Processed output content
    """
    self.pending_buffer += new_content
    output_content = ""

    # Inside <think> the buffer is flushed verbatim.
    if self.in_thinking_tag:
        output_content = self.pending_buffer
        self.pending_buffer = ""
        return output_content

    while self.pending_buffer:
        start_pos = self.pending_buffer.find(self.tool_call_start_token)
        end_pos = self.pending_buffer.find(self.tool_call_end_token)

        # Consume whichever tag occurs first in the buffer.
        if start_pos != -1 and (end_pos == -1 or start_pos < end_pos):
            tag_pos, tag_len = start_pos, len(self.tool_call_start_token)
        elif end_pos != -1:
            tag_pos, tag_len = end_pos, len(self.tool_call_end_token)
        else:
            # No full tag present: keep a possible tag prefix buffered,
            # otherwise flush everything.
            if self._is_potential_tag_start(self.pending_buffer):
                break
            output_content += self.pending_buffer
            self.pending_buffer = ""
            break

        # Emit text before the tag; drop the tag itself.
        output_content += self.pending_buffer[:tag_pos]
        self.pending_buffer = self.pending_buffer[tag_pos + tag_len :]

    return output_content
|
||||
|
||||
def _reset_streaming_state(self) -> None:
    """Restore the per-response streaming bookkeeping to its initial state."""
    self.streaming_state = dict(
        current_tool_index=-1,  # no tool is being streamed yet
        tool_ids=[],
        sent_tools=[],
    )
|
||||
|
||||
def _advance_to_next_tool(self) -> None:
    """Move the streaming cursor to the next tool in the sequence."""
    current = int(self.streaming_state["current_tool_index"])
    self.streaming_state["current_tool_index"] = current + 1
|
||||
|
||||
def _set_current_tool_index(self, index: int) -> None:
    """Point the streaming cursor at tool *index*.

    Args:
        index: Tool index to set.
    """
    self.streaming_state["current_tool_index"] = index
|
||||
|
||||
def _get_current_tool_index(self) -> int:
    """Return the index of the tool currently being streamed.

    Returns:
        The current tool index (-1 before any tool has started).
    """
    return int(self.streaming_state["current_tool_index"])
|
||||
|
||||
def _get_next_unsent_tool_index(self, tool_count: int) -> int:
    """Return the first tool index whose name has not been emitted yet.

    Args:
        tool_count: Total number of tools detected so far.

    Returns:
        Index of the next unsent tool, or -1 if every tool's name has
        already been streamed.
    """
    sent_tools = list(self.streaming_state["sent_tools"])
    for idx in range(tool_count):
        # Indices beyond the tracked state are unsent by definition.
        if idx >= len(sent_tools) or not sent_tools[idx]["sent_name"]:
            return idx
    return -1
|
||||
|
||||
def _ensure_state_arrays(self, tool_count: int) -> None:
    """Grow the per-tool state lists to hold at least *tool_count* entries.

    Args:
        tool_count: Number of tools to prepare for.
    """
    sent_tools = list(self.streaming_state["sent_tools"])
    tool_ids = list(self.streaming_state["tool_ids"])

    # Each new slot gets its own freshly generated tool-call id.
    for _ in range(tool_count - len(sent_tools)):
        sent_tools.append(
            {
                "sent_name": False,
                "sent_arguments": "",
                "id": make_tool_call_id(),
            }
        )

    if len(tool_ids) < tool_count:
        tool_ids.extend([None] * (tool_count - len(tool_ids)))

    self.streaming_state["sent_tools"] = sent_tools
    self.streaming_state["tool_ids"] = tool_ids
|
||||
|
||||
def _detect_tools_in_text(self, text: str) -> int:
    """Count tool calls in *text* by counting ``"name"`` fields.

    Args:
        text: Text to analyze.

    Returns:
        Number of tool-name patterns found.
    """
    return len(self.tool_name_pattern.findall(text))
|
||||
|
||||
def _find_tool_boundaries(self, text: str) -> list[tuple[int, int]]:
    """
    Find the boundaries of tool calls in text.

    A complete tool call is a balanced JSON object containing both a
    ``"name"`` and an ``"arguments"`` key. An unterminated trailing object
    that already contains a ``"name"`` key is reported as well, so
    streaming callers can see a partially generated final tool call.

    Args:
        text: Text to analyze

    Returns:
        List of (start, end) positions for tool calls
    """
    boundaries = []
    i = 0
    while i < len(text):
        if text[i] == "{":
            start = i
            depth = 0
            has_name = False

            while i < len(text):
                if text[i] == "{":
                    depth += 1
                elif text[i] == "}":
                    depth -= 1
                    if depth == 0:
                        end = i + 1
                        segment = text[start:end]
                        # Only balanced objects with both mandatory keys
                        # count as complete tool calls.
                        if '"name"' in segment and '"arguments"' in segment:
                            boundaries.append((start, end))
                        break

                # Remember once the name key has appeared; used for the
                # partial-object case below.
                if not has_name and '"name"' in text[start : i + 1]:
                    has_name = True

                i += 1

            # Unterminated object at end of text: report it as a partial
            # tool call if its name has already been generated.
            if depth > 0 and has_name:
                boundaries.append((start, i))
        else:
            i += 1
    return boundaries
|
||||
|
||||
def _extract_tool_args(self, tool_content: str, args_match: re.Match[str]) -> str:
    """
    Extract tool arguments from tool content.

    Args:
        tool_content: Tool call content
        args_match: Regex match for arguments pattern

    Returns:
        Extracted arguments as string
    """
    args_start_pos = args_match.end()
    remaining_content = tool_content[args_start_pos:]

    if remaining_content.strip().startswith("{"):
        # Arguments are a JSON object: return the balanced-brace span.
        depth = 0
        for i, char in enumerate(remaining_content):
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    return remaining_content[: i + 1]
        # NOTE: deliberately falls through to the final return when the
        # object is still unterminated (streaming case).
    else:
        # Non-object arguments: take everything up to the first '}'.
        args_end = remaining_content.find("}")
        if args_end > 0:
            return remaining_content[:args_end].strip()

    # Fallback for unterminated/edge cases: strip enclosing braces.
    return remaining_content.rstrip("}").strip()
|
||||
|
||||
def _get_current_tool_content(
    self, text: str, tool_index: int
) -> tuple[str | None, str | None]:
    """
    Get the content of a specific tool by index.

    Args:
        text: Text containing tool calls
        tool_index: Index of tool to extract

    Returns:
        Tuple of (tool_name, tool_arguments) or (None, None) if not found
    """
    boundaries = self._find_tool_boundaries(text)

    if tool_index >= len(boundaries):
        return None, None

    start, end = boundaries[tool_index]
    tool_content = text[start:end]

    name_match = self.tool_name_pattern.search(tool_content)
    name = name_match.group(1) if name_match else None

    args_match = self.tool_args_pattern.search(tool_content)
    if args_match:
        try:
            args_text = self._extract_tool_args(tool_content, args_match)
            return name, args_text
        except Exception:
            # Best effort: strip trailing braces from whatever follows
            # the "arguments" key.
            remaining_content = tool_content[args_match.end() :]
            args_text = remaining_content.rstrip("}").strip()
            return name, args_text

    # Name present but no arguments generated yet.
    return name, None
|
||||
|
||||
def _handle_tool_name_streaming(
    self, tool_content: str, tool_count: int
) -> DeltaMessage | None:
    """
    Handle streaming of tool names.

    Args:
        tool_content: Content containing tool calls
        tool_count: Total number of tools

    Returns:
        DeltaMessage with tool name or None if no tool to stream
    """
    next_idx = self._get_next_unsent_tool_index(tool_count)

    # All names sent already — argument streaming takes over.
    if next_idx == -1:
        return None

    boundaries = self._find_tool_boundaries(tool_content)
    if next_idx >= len(boundaries):
        return None

    tool_name, _ = self._get_current_tool_content(tool_content, next_idx)
    if not tool_name:
        return None

    # Mark this tool as active and record that its name has been emitted.
    self._set_current_tool_index(next_idx)
    sent_tools = list(self.streaming_state["sent_tools"])
    tool_ids = list(self.streaming_state["tool_ids"])

    tool_id = sent_tools[next_idx]["id"]
    tool_ids[next_idx] = tool_id
    sent_tools[next_idx]["sent_name"] = True

    self.streaming_state["sent_tools"] = sent_tools
    self.streaming_state["tool_ids"] = tool_ids

    return DeltaMessage(
        tool_calls=[
            DeltaToolCall(
                index=next_idx,
                type="function",
                id=tool_id,
                function=DeltaFunctionCall(name=tool_name).model_dump(
                    exclude_none=True
                ),
            )
        ]
    )
|
||||
|
||||
def _handle_tool_args_streaming(
    self, tool_content: str, tool_count: int
) -> DeltaMessage | None:
    """
    Handle streaming of tool arguments.

    Args:
        tool_content: Content containing tool calls
        tool_count: Total number of tools

    Returns:
        DeltaMessage with tool arguments or None if no arguments to stream
    """
    current_idx = self._get_current_tool_index()

    if current_idx < 0 or current_idx >= tool_count:
        return None

    tool_name, tool_args = self._get_current_tool_content(tool_content, current_idx)
    if not tool_name or tool_args is None:
        return None

    sent_tools = list(self.streaming_state["sent_tools"])

    # The tool's name must be emitted before any argument fragments.
    if not sent_tools[current_idx]["sent_name"]:
        return None

    clean_args = self._clean_duplicate_braces(tool_args)
    sent_args = sent_tools[current_idx]["sent_arguments"]

    if clean_args != sent_args:
        if sent_args and clean_args.startswith(sent_args):
            # Incremental growth: emit only the new suffix.
            args_delta = extract_intermediate_diff(clean_args, sent_args)
            if args_delta:
                args_delta = self._clean_delta_braces(args_delta)
                sent_tools[current_idx]["sent_arguments"] = clean_args
                self.streaming_state["sent_tools"] = sent_tools

                # A closing brace means this tool's arguments are done.
                if clean_args.endswith("}"):
                    self._advance_to_next_tool()

                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=current_idx,
                            function=DeltaFunctionCall(
                                arguments=args_delta
                            ).model_dump(exclude_none=True),
                        )
                    ]
                )
        elif not sent_args and clean_args:
            # First argument fragment for this tool.
            clean_args_delta = self._clean_delta_braces(clean_args)
            sent_tools[current_idx]["sent_arguments"] = clean_args
            self.streaming_state["sent_tools"] = sent_tools

            if clean_args.endswith("}"):
                self._advance_to_next_tool()

            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=current_idx,
                        function=DeltaFunctionCall(
                            arguments=clean_args_delta
                        ).model_dump(exclude_none=True),
                    )
                ]
            )
        # NOTE(review): when clean_args diverges from a non-empty
        # sent_args prefix, the delta is dropped (returns None) —
        # presumably the next delta re-syncs; confirm against callers.

    return None
|
||||
|
||||
def _is_end_tool_calls(self, current_text: str) -> bool:
    """Check whether a ``</tool_calls>`` tag occurs outside any think block.

    Args:
        current_text: Full text generated so far.

    Returns:
        True if at least one end token lies outside every
        ``<think>...</think>`` region.
    """
    if self.tool_call_end_token not in current_text:
        return False

    # Spans of every <think>...</think> region in the text.
    think_regions = [
        (m.start(), m.end())
        for m in re.finditer(
            self.thinking_tag_pattern, current_text, flags=re.DOTALL
        )
    ]

    # Walk every occurrence of the end token, left to right.
    search_start = 0
    while True:
        pos = current_text.find(self.tool_call_end_token, search_start)
        if pos == -1:
            return False
        if not any(t_start <= pos < t_end for t_start, t_end in think_regions):
            return True
        search_start = pos + 1
|
||||
|
||||
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None:
    """
    Extract tool calls from a streaming delta.

    Args:
        previous_text: Text generated before this delta.
        current_text: Full text generated so far.
        delta_text: Newly generated text.
        previous_token_ids: Token ids before this delta.
        current_token_ids: All token ids generated so far.
        delta_token_ids: Newly generated token ids.
        request: Chat completion request.

    Returns:
        DeltaMessage with content and/or tool-call deltas, or None when
        nothing should be emitted for this delta.
    """
    self._update_thinking_state(current_text)

    # Thinking content streams through verbatim.
    if self.in_thinking_tag:
        return DeltaMessage(content=delta_text)

    # Route deltas that may contain (part of) a tag through the buffer.
    if self._should_buffer_content(delta_text):
        buffered_output = self._process_buffer(delta_text)
        return DeltaMessage(content=buffered_output) if buffered_output else None

    # After a real </tool_calls>, everything is plain content again.
    if self._is_end_tool_calls(current_text):
        return DeltaMessage(content=delta_text)

    safe_content, potential_tag = self._split_content_for_buffering(delta_text)
    if potential_tag:
        self.pending_buffer += potential_tag
        return DeltaMessage(content=safe_content) if safe_content else None

    processed_current_text = self.preprocess_model_output(current_text)

    if self.tool_call_start_token not in processed_current_text:
        # The only start tokens seen so far were inside <think>; suppress
        # tag-related deltas and pass the rest through as content.
        if (
            self.tool_call_end_token in delta_text
            and self.tool_call_start_token in current_text
        ):
            return None
        if delta_text.strip() == "" and self.tool_call_start_token in current_text:
            return None
        if (
            self._get_current_tool_index() != -1
            and self.tool_call_end_token in current_text
        ):
            self._reset_streaming_state()
        return DeltaMessage(content=delta_text)

    # Swallow a delta that is exactly the start-tag token.
    if (
        self.tool_call_start_token_id is not None
        and self.tool_call_start_token_id in delta_token_ids
        and len(delta_token_ids) == 1
    ):
        return None

    original_tool_start = self._find_tool_start_outside_thinking(current_text)
    if original_tool_start is None:
        return None

    # Emit any plain content that precedes the tool-call section.
    content_before_tools = self._extract_content_before_tools(
        current_text, delta_text, original_tool_start
    )
    if content_before_tools:
        return DeltaMessage(content=content_before_tools)

    try:
        tool_content = self._extract_tool_content(current_text, original_tool_start)
        current_tools_count = self._detect_tools_in_text(tool_content)

        if current_tools_count == 0:
            return None

        if self._get_current_tool_index() == -1:
            self._reset_streaming_state()

        self._ensure_state_arrays(current_tools_count)

        # Names are streamed before arguments for each tool.
        return self._handle_tool_name_streaming(
            tool_content, current_tools_count
        ) or self._handle_tool_args_streaming(tool_content, current_tools_count)

    except Exception:
        # FIX: this was logger.exception(msg1, msg2) — the second string
        # was treated as a %-format argument for a message with no
        # placeholder, breaking the log record formatting.
        logger.exception(
            "An unexpected error occurred during streaming tool call handling."
        )
        return None
|
||||
|
||||
def _find_tool_start_outside_thinking(self, current_text: str) -> int | None:
    """
    Find the start position of tool calls outside of thinking tags.

    Args:
        current_text: Current text to search

    Returns:
        Position of tool call start or None if not found
    """
    # The <think> spans do not change between candidate positions, so
    # compute them once instead of re-scanning the whole text on every
    # loop iteration (was O(n) regex passes per candidate).
    think_regions = [
        (m.start(), m.end())
        for m in re.finditer(
            self.thinking_tag_pattern, current_text, flags=re.DOTALL
        )
    ]

    search_start = 0
    while True:
        pos = current_text.find(self.tool_call_start_token, search_start)
        if pos == -1:
            return None

        in_think = any(
            t_start <= pos < t_end for t_start, t_end in think_regions
        )
        if not in_think:
            return pos

        search_start = pos + 1
|
||||
|
||||
def _extract_content_before_tools(
    self, current_text: str, delta_text: str, tool_start: int
) -> str | None:
    """Return the portion of *delta_text* that precedes the tool section.

    Args:
        current_text: Full text generated so far.
        delta_text: Newly generated text.
        tool_start: Start position of the tool-call section.

    Returns:
        The content part of the delta, or None when the delta lies
        entirely inside the tool-call section (or there is none).
    """
    if tool_start <= 0:
        return None

    delta_start_pos = len(current_text) - len(delta_text)
    if delta_start_pos >= tool_start:
        return None

    content_part = delta_text
    # Clip the delta if it straddles the start of the tool-call section.
    if delta_start_pos + len(delta_text) > tool_start:
        content_part = delta_text[: tool_start - delta_start_pos]
    return content_part or None
|
||||
|
||||
def _extract_tool_content(self, current_text: str, tool_start: int) -> str:
    """Return the text between the tool-call start tag and its end tag.

    Args:
        current_text: Full text generated so far.
        tool_start: Position of the ``<tool_calls>`` tag.

    Returns:
        Everything after the start tag, truncated at the end tag when one
        has already been generated.
    """
    body = current_text[tool_start + len(self.tool_call_start_token) :]
    end_pos = body.find(self.tool_call_end_token)
    # Keep the full tail while the closing tag has not appeared yet.
    return body if end_pos == -1 else body[:end_pos]
|
||||
585
vllm/tool_parsers/mistral_tool_parser.py
Normal file
585
vllm/tool_parsers/mistral_tool_parser.py
Normal file
@@ -0,0 +1,585 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum, auto
|
||||
from random import choices
|
||||
from string import ascii_letters, digits
|
||||
from typing import Any
|
||||
|
||||
import ijson
|
||||
import regex as re
|
||||
from pydantic import Field
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
ALPHANUMERIC = ascii_letters + digits
|
||||
|
||||
|
||||
class StreamingState(Enum):
    """Enum for tracking the current streaming parsing state."""

    # No [TOOL_CALLS] token has been consumed yet.
    WAITING_FOR_TOOL_START = auto()
    WAITING_FOR_TOOL_KEY = (
        auto()
    )  # waiting for the "name" or "arguments" key to be complete
    # Consuming the tool name.
    PARSING_NAME = auto()
    # Name fully parsed; arguments not yet started.
    PARSING_NAME_COMPLETED = auto()
    WAITING_FOR_ARGUMENTS_START = auto()
    # Consuming the arguments JSON object.
    PARSING_ARGUMENTS = auto()
    PARSING_ARGUMENTS_COMPLETED = auto()
    # One tool call fully parsed; more may follow.
    TOOL_COMPLETE = auto()
    # The whole tool-call array has been closed.
    ALL_TOOLS_COMPLETE = auto()
|
||||
|
||||
|
||||
class MistralToolCall(ToolCall):
    """ToolCall whose id satisfies Mistral's 9-char alphanumeric format."""

    # Default to a freshly generated Mistral-style id per instance.
    id: str = Field(default_factory=lambda: MistralToolCall.generate_random_id())

    @staticmethod
    def generate_random_id():
        """Return a random id in Mistral's required format."""
        # Mistral Tool Call Ids must be alphanumeric with a length of 9.
        # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
        return "".join(choices(ALPHANUMERIC, k=9))

    @staticmethod
    def is_valid_id(id: str) -> bool:
        """Return True if *id* is a valid Mistral tool-call id."""
        return id.isalnum() and len(id) == 9
|
||||
|
||||
|
||||
def _is_pre_v11_tokeniser(model_tokenizer: TokenizerLike) -> bool:
    """Return True unless the tokenizer is a MistralTokenizer of version >= 11."""
    if not isinstance(model_tokenizer, MistralTokenizer):
        return True
    return model_tokenizer.version < 11
|
||||
|
||||
|
||||
class MistralToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for Mistral 7B Instruct v0.3, intended for use with
|
||||
- [`mistral_common`](https://github.com/mistralai/mistral-common/)
|
||||
- the examples/tool_chat_template_mistral.jinja template.
|
||||
|
||||
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: TokenizerLike):
    """Initialize the Mistral tool parser.

    Args:
        tokenizer: Tokenizer of the served model.

    Raises:
        RuntimeError: If the [TOOL_CALLS] token is missing from the vocab.
    """
    super().__init__(tokenizer)

    if not isinstance(self.model_tokenizer, MistralTokenizer):
        logger.info("Non-Mistral tokenizer detected when using a Mistral model...")

    # initialize properties used for state when parsing tool calls in
    # streaming mode
    self.prev_tool_call_arr: list[dict[str, Any]] = []
    self.current_tool_id: int = -1
    self.streaming_state: StreamingState = StreamingState.WAITING_FOR_TOOL_START

    # For streaming pre v11 tokenizer tool calls
    self.current_tool_name: str | None = None
    self.current_tool_mistral_id: str | None = None
    self.starting_new_tool = False
    if _is_pre_v11_tokeniser(self.model_tokenizer):
        # Incremental (coroutine-based) JSON parser used to track the
        # tool-call array as it streams in.
        self.parse_coro = ijson.parse_coro(
            self.update_stream_state_pre_v11_tokenizer()
        )

    self.bot_token = "[TOOL_CALLS]"
    self.bot_token_id = self.vocab.get(self.bot_token)
    # Fallback pattern for locating the JSON tool-call array in raw text.
    self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
    # Cached: tokenizer-version check is used on every non-streaming parse.
    self._is_pre_v11 = _is_pre_v11_tokeniser(self.model_tokenizer)

    if self.bot_token_id is None:
        raise RuntimeError(
            "Mistral Tool Parser could not locate the tool call token in "
            "the tokenizer!"
        )
|
||||
|
||||
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
    """Adjust detokenization settings for tool-calling requests.

    Args:
        request: Incoming chat completion request.

    Returns:
        The (possibly modified) request.
    """
    request = super().adjust_request(request)
    tools_requested = bool(request.tools) and request.tool_choice != "none"
    if tools_requested and not isinstance(self.model_tokenizer, MistralTokenizer):
        # Do not skip special tokens when using chat template
        # with Mistral parser as TOOL_CALL token is needed
        # for tool detection.
        # Note: we don't want skip_special_tokens=False
        # with MistralTokenizer as it is incompatible
        request.skip_special_tokens = False
    return request
|
||||
|
||||
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """
    Extract the tool calls from a complete model response. Requires
    find-and-replacing single quotes with double quotes for JSON parsing,
    make sure your tool call arguments don't ever include quotes!
    """

    # case -- if a tool call token is not present, return a text response
    if self.bot_token not in model_output:
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    # first remove the BOT token
    tool_content = model_output.replace(self.bot_token, "").strip()

    try:
        try:
            if not self._is_pre_v11:
                # v11+ format: each [TOOL_CALLS] is followed by a plain
                # function name and then a serialized arguments object,
                # e.g. [TOOL_CALLS]add{"a": 3, "b": 4}.
                function_call_arr = []
                for single_tool_content in model_output.split(self.bot_token):
                    if "{" not in single_tool_content:
                        continue

                    end_name = single_tool_content.find("{")
                    fn_name, args = (
                        single_tool_content[:end_name],
                        single_tool_content[end_name:],
                    )

                    # fn_name is encoded outside serialized json dump
                    # only arguments are serialized
                    function_call_arr.append(
                        {"name": fn_name, "arguments": json.loads(args)}
                    )
            else:
                # pre-v11 format: the whole payload is a JSON array of
                # {"name": ..., "arguments": ...} objects.
                function_call_arr = json.loads(tool_content)
        except json.JSONDecodeError:
            # use a regex to find the part corresponding to the tool call.
            # NOTE: This use case should not happen if the model is trained
            # correctly. It's an easy possible fix so it's included, but
            # can be brittle for very complex / highly nested tool calls
            raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
            function_call_arr = json.loads(raw_tool_call)

        # Tool Call
        tool_calls: list[MistralToolCall] = [
            MistralToolCall(
                type="function",
                function=FunctionCall(
                    name=raw_function_call["name"],
                    # function call args are JSON but as a string
                    arguments=json.dumps(
                        raw_function_call["arguments"], ensure_ascii=False
                    ),
                ),
            )
            for raw_function_call in function_call_arr
        ]

        # get any content before the tool call
        content = model_output.split(self.bot_token)[0]
        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=content if len(content) > 0 else None,
        )

    except Exception:
        logger.exception("Error in extracting tool call from response.")
        # return information to just treat the tool call as regular JSON
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=tool_content
        )
|
||||
|
||||
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None:
    """Dispatch streaming extraction by tokenizer generation.

    Routes to the pre-v11 or v11+ streaming implementation once the
    [TOOL_CALLS] token has been generated; before that, all output is
    treated as plain content.
    """
    # No [TOOL_CALLS] token yet: everything so far is ordinary content.
    if self.bot_token_id not in current_token_ids:
        return DeltaMessage(content=delta_text)

    # The tool-call token has appeared — parse the stream as tool calls.
    try:
        if not _is_pre_v11_tokeniser(self.model_tokenizer):
            return self._extract_tool_calls_streaming(
                delta_text=delta_text, delta_token_ids=delta_token_ids
            )
        return self._extract_tool_calls_streaming_pre_v11_tokenizer(
            delta_text=delta_text,
            delta_token_ids=delta_token_ids,
        )
    except Exception:
        logger.exception("Error trying to handle streaming tool call.")
        return None
|
||||
|
||||
def _extract_tool_calls_streaming(
    self,
    delta_text: str,
    delta_token_ids: Sequence[int],
) -> DeltaMessage | None:
    """
    Extracts tool calls for Mistral models
    doing tool calls of the following format:
    `[TOOL_CALLS]add{"a": 3.5, "b": 4}`

    Returns a DeltaMessage carrying any plain-text prefix plus the
    tool-call deltas parsed from this chunk, an empty DeltaMessage when
    the tools are finished (so the caller can set finish_reason), or
    None while more text is needed (e.g. a partially received name).
    """
    # Plain text generated before the [TOOL_CALLS] token; forwarded as content.
    additional_content: str = ""
    if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START:
        # this is the first tool call
        assert self.bot_token_id in delta_token_ids
        if not delta_text.startswith(self.bot_token):
            # Split off the text prefix and normalize delta_text so it
            # starts exactly at the [TOOL_CALLS] token.
            additional_content += delta_text.split(self.bot_token)[0]
            delta_text = self.bot_token + "".join(
                delta_text.split(self.bot_token)[1:]
            )

    # get any content before the tool call
    delta_tool_calls = self._generate_delta_tool_call(delta_text)
    if not additional_content and len(delta_tool_calls) == 0:
        if self.streaming_state in [
            StreamingState.PARSING_ARGUMENTS,
            StreamingState.PARSING_ARGUMENTS_COMPLETED,
            StreamingState.TOOL_COMPLETE,
            StreamingState.ALL_TOOLS_COMPLETE,
        ]:
            # Return an empty DeltaMessage once the tool calls are all done
            # so that finish_reason gets set.
            return DeltaMessage()
        else:
            # return None when the tool is not likely to be finished
            # This can occur when the name is being parsed for example
            # and we wait for the name to be complete
            # before sending the function name
            return None

    delta = DeltaMessage()
    if additional_content:
        delta.content = additional_content
    if len(delta_tool_calls) > 0:
        delta.tool_calls = delta_tool_calls

    # HACK: serving_chat.py inspects the internal state of tool parsers
    # when determining its final streaming delta, automatically
    # adding autocompleted JSON.
    # These two lines avoid that nonsense while ensuring finish_reason
    # is set to tool_calls when at least one tool is called.
    if delta_tool_calls and not self.prev_tool_call_arr:
        self.prev_tool_call_arr = [{"arguments": {}}]
    return delta
|
||||
|
||||
def _generate_delta_tool_call(self, delta_text: str) -> list[DeltaToolCall]:
    """Parse a chunk of `name{json-args}` tool-call text into DeltaToolCall
    objects, advancing the parser's state machine.

    Recurses when the chunk also contains the start of the next tool call
    (another [TOOL_CALLS] token after the current call's arguments).
    """
    if delta_text == "" or delta_text is None:
        return []
    delta_function_name = None
    tool_id = None
    # A [TOOL_CALLS] token while not already inside a call starts a new
    # tool: bump the tool index and switch to name parsing.
    if self.streaming_state not in [
        StreamingState.PARSING_NAME,
        StreamingState.PARSING_ARGUMENTS,
    ] and delta_text.startswith(self.bot_token):
        self.current_tool_id += 1
        self.streaming_state = StreamingState.PARSING_NAME
        delta_text = delta_text.replace(self.bot_token, "", 1)
    if self.streaming_state == StreamingState.PARSING_NAME:
        if self.current_tool_name is None:
            self.current_tool_name = ""
        # The name stops where the arguments start
        # And the arguments start with the `{` char
        if "{" in delta_text:
            tool_id = MistralToolCall.generate_random_id()
            delta_function_name = delta_text.split("{")[0]
            self.current_tool_name += delta_function_name
            delta_text = delta_text[len(delta_function_name) :]
            self.streaming_state = StreamingState.PARSING_ARGUMENTS
        else:
            # we want to send the tool name once it's complete
            self.current_tool_name += delta_text
            return []
    if self.streaming_state == StreamingState.PARSING_ARGUMENTS:
        next_function_text = None
        if self.bot_token in delta_text:
            # current tool call is over
            delta_arguments = ""
            delta_arguments += delta_text.split(self.bot_token)[0]
            next_function_text = delta_text[len(delta_arguments) :]
            self.streaming_state = StreamingState.TOOL_COMPLETE
        else:
            delta_arguments = delta_text
        ret = []
        if self.current_tool_name or delta_arguments:
            ret += [
                DeltaToolCall(
                    index=self.current_tool_id,
                    type="function",
                    id=tool_id,
                    function=DeltaFunctionCall(
                        name=self.current_tool_name, arguments=delta_arguments
                    ).model_dump(exclude_none=True),
                )
            ]
            # Name has been emitted (if any); don't repeat it in later deltas.
            self.current_tool_name = None
        if next_function_text:
            # Handle the next tool call contained in this same chunk.
            ret += self._generate_delta_tool_call(next_function_text)
        return ret
    # Should not happen
    return []
|
||||
|
||||
@ijson.coroutine
def update_stream_state_pre_v11_tokenizer(self):
    """ijson event-sink coroutine mirroring the incremental parse of
    `[{"name": ..., "arguments": {...}}, ...]` into StreamingState.

    Each (prefix, event, value) triple pushed by ijson advances the state
    machine. Conditions are independent `if`s (not elif) so that several
    transitions can fire for events arriving in the same send.
    """
    while True:
        (prefix, event, value) = yield

        if prefix == "item" and event == "start_map":
            # Opening `{` of a tool-call object.
            self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY
        if prefix == "item" and event == "map_key" and value == "name":
            self.streaming_state = StreamingState.PARSING_NAME
        if prefix == "item.name" and event == "string":
            # The complete name string arrives in a single event.
            self.current_tool_name = value
            self.streaming_state = StreamingState.PARSING_NAME_COMPLETED
        if prefix == "item" and event == "map_key" and value == "arguments":
            self.streaming_state = StreamingState.WAITING_FOR_ARGUMENTS_START
        if prefix == "item.arguments" and event == "start_map":
            self.streaming_state = StreamingState.PARSING_ARGUMENTS
        if prefix == "item.arguments" and event == "end_map":
            self.streaming_state = StreamingState.PARSING_ARGUMENTS_COMPLETED
        if prefix == "item" and event == "end_map":
            # Closing `}` of the tool-call object.
            self.streaming_state = StreamingState.TOOL_COMPLETE
        if prefix == "" and event == "end_array":
            # Closing `]` of the top-level list: no more tools follow.
            self.streaming_state = StreamingState.ALL_TOOLS_COMPLETE
|
||||
|
||||
def _extract_tool_calls_streaming_pre_v11_tokenizer(
    self,
    delta_text: str,
    delta_token_ids: Sequence[int],
) -> DeltaMessage | None:
    """
    Extracts tool calls for Mistral models
    doing tool calls of the following format:
    `[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}`

    Feeds the delta text to the ijson coroutine in state-aware slices so
    that state transitions can be attributed to the exact text span that
    triggered them, and accumulates the result into tool-call deltas.
    """
    assert self.parse_coro is not None
    content = None
    delta_tool_calls: list[DeltaToolCall] = []
    current_tool_call: DeltaToolCall = DeltaToolCall(
        index=self.current_tool_id, type="function"
    )
    current_tool_call_modified = False
    if self.bot_token_id in delta_token_ids:
        # this is the first tool call
        if not delta_text.startswith(self.bot_token):
            content = delta_text.split(self.bot_token)[0]
        delta_text = "".join(delta_text.split(self.bot_token)[1:])

    # Cut smartly the delta text to catch the ijson events
    # as ijson does not give us the index in the text at each event.
    # We need to cut so that we know
    # where in the text the events are emitted from.
    while len(delta_text) > 0:
        streaming_state_before_parse = self.streaming_state

        if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START:
            delta_to_be_parsed, delta_text = self._split_delta(
                delta_text=delta_text,
                stop_after_opening_curly_braces=1,
            )
        elif self.streaming_state == StreamingState.WAITING_FOR_TOOL_KEY:
            # Wait until another key is sent
            # or the current tool is completed
            delta_to_be_parsed, delta_text = self._split_delta(
                delta_text=delta_text,
                stop_after_colon=1,
                stop_after_opening_curly_braces=1,
                # if the tool ends, we want to separate
                # at the start of the next tool
            )
        elif self.streaming_state == StreamingState.PARSING_NAME:
            delta_to_be_parsed, delta_text = self._split_delta(
                delta_text=delta_text,
                stop_after_comma=1,
                stop_after_closing_brackets=1,
            )
        elif self.streaming_state == StreamingState.WAITING_FOR_ARGUMENTS_START:
            delta_to_be_parsed, delta_text = self._split_delta(
                delta_text=delta_text,
                stop_after_opening_curly_braces=1,
            )
        elif self.streaming_state == StreamingState.PARSING_ARGUMENTS:
            delta_to_be_parsed, delta_text = self._split_delta(
                delta_text=delta_text,
                stop_after_closing_curly_braces=1,
                # we could be more clever
                # by listening to item.arguments.* start_map events
                # and know how many curly braces we can allow
            )
        elif self.streaming_state in [
            StreamingState.PARSING_ARGUMENTS_COMPLETED,
            StreamingState.PARSING_NAME_COMPLETED,
        ]:
            delta_to_be_parsed, delta_text = self._split_delta(
                delta_text=delta_text,
                stop_after_closing_curly_braces=1,
                stop_after_closing_brackets=1,
            )
        elif self.streaming_state == StreamingState.TOOL_COMPLETE:
            delta_to_be_parsed, delta_text = self._split_delta(
                delta_text=delta_text,
                stop_after_opening_curly_braces=1,
                stop_after_closing_brackets=1,
            )
        elif self.streaming_state == StreamingState.ALL_TOOLS_COMPLETE:
            # Everything after the closing `]` is plain content.
            content = delta_text
            delta_text = ""
        else:
            delta_to_be_parsed = delta_text
            delta_text = ""

        if self.streaming_state != StreamingState.ALL_TOOLS_COMPLETE:
            # Push the slice through ijson; this may update streaming_state
            # via update_stream_state_pre_v11_tokenizer.
            self.parse_coro.send(delta_to_be_parsed.encode("utf-8"))

        # Given the parsed text and the possible streaming state change,
        # let's add to the tool delta
        if (
            (streaming_state_before_parse != self.streaming_state)
            and streaming_state_before_parse
            in [StreamingState.WAITING_FOR_TOOL_START, StreamingState.TOOL_COMPLETE]
            and self.streaming_state
            not in [
                StreamingState.ALL_TOOLS_COMPLETE,
                StreamingState.TOOL_COMPLETE,
                StreamingState.WAITING_FOR_TOOL_START,
            ]
        ):
            # starting a new tool call
            if current_tool_call_modified:
                # Flush the previous tool call's accumulated delta first.
                if self.current_tool_mistral_id is not None:
                    current_tool_call.id = self.current_tool_mistral_id
                    self.current_tool_mistral_id = None
                delta_tool_calls.append(current_tool_call)
                current_tool_call_modified = False
            self.current_tool_id += 1
            self.current_tool_mistral_id = MistralToolCall.generate_random_id()
            current_tool_call = DeltaToolCall(
                index=self.current_tool_id,
                type="function",
            )
        if current_tool_call.function is None:
            current_tool_call.function = DeltaFunctionCall()

        if self.current_tool_name is not None:
            # we have the complete tool name
            current_tool_call_modified = True
            current_tool_call.function.name = self.current_tool_name
            self.current_tool_name = None
            if self.streaming_state == StreamingState.PARSING_NAME_COMPLETED:
                self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY
        if self.streaming_state in [
            StreamingState.PARSING_ARGUMENTS,
            StreamingState.PARSING_ARGUMENTS_COMPLETED,
        ]:
            if self.streaming_state == StreamingState.PARSING_ARGUMENTS_COMPLETED:
                self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY
            # the delta_to_be_parsed is part of arguments.
            current_tool_call_modified = True
            if current_tool_call.function.arguments is None:
                current_tool_call.function.arguments = delta_to_be_parsed
            else:
                current_tool_call.function.arguments += delta_to_be_parsed
            if streaming_state_before_parse != StreamingState.PARSING_ARGUMENTS:
                # It's the first chunk of arg. let's lstrip it
                current_tool_call.function.arguments = (
                    current_tool_call.function.arguments.lstrip()
                )

    if current_tool_call_modified:
        # Flush the last (possibly partial) tool-call delta.
        if self.current_tool_mistral_id is not None:
            current_tool_call.id = self.current_tool_mistral_id
            self.current_tool_mistral_id = None
        delta_tool_calls.append(current_tool_call)

    # HACK: serving_chat.py inspects the internal state of tool parsers
    # when determining it's final streaming delta, automatically
    # adding autocompleted JSON.
    # These two lines avoid that nonsense while ensuring finish_reason
    # is set to tool_calls when at least one tool is called.
    if delta_tool_calls and not self.prev_tool_call_arr:
        self.prev_tool_call_arr = [{"arguments": {}}]

    if content or len(delta_tool_calls) > 0:
        delta_message = DeltaMessage()
        if content:
            delta_message.content = content
        if len(delta_tool_calls) > 0:
            delta_message.tool_calls = delta_tool_calls
        return delta_message
    else:
        if self.streaming_state == StreamingState.ALL_TOOLS_COMPLETE:
            # Empty message lets the caller set finish_reason.
            return DeltaMessage()
        else:
            return None
|
||||
|
||||
def _split_delta(
|
||||
self,
|
||||
delta_text: str,
|
||||
stop_after_quotes: int = -1,
|
||||
stop_after_opening_curly_braces: int = -1,
|
||||
stop_after_closing_curly_braces: int = -1,
|
||||
stop_after_closing_brackets: int = -1,
|
||||
stop_after_colon: int = -1,
|
||||
stop_after_comma=-1,
|
||||
) -> tuple[str, str]:
|
||||
delta_to_be_parsed = ""
|
||||
for i, c in enumerate(delta_text):
|
||||
if c in ['"', "'"]:
|
||||
delta_to_be_parsed += c
|
||||
stop_after_quotes -= 1
|
||||
if stop_after_quotes == 0:
|
||||
return (delta_to_be_parsed, delta_text[i + 1 :])
|
||||
elif c == "{":
|
||||
delta_to_be_parsed += c
|
||||
stop_after_opening_curly_braces -= 1
|
||||
if stop_after_opening_curly_braces == 0:
|
||||
return (delta_to_be_parsed, delta_text[i + 1 :])
|
||||
elif c == "}":
|
||||
delta_to_be_parsed += c
|
||||
stop_after_closing_curly_braces -= 1
|
||||
if stop_after_closing_curly_braces == 0:
|
||||
return (delta_to_be_parsed, delta_text[i + 1 :])
|
||||
elif c == "]":
|
||||
delta_to_be_parsed += c
|
||||
stop_after_closing_brackets -= 1
|
||||
if stop_after_closing_brackets == 0:
|
||||
return (delta_to_be_parsed, delta_text[i + 1 :])
|
||||
elif c == ":":
|
||||
delta_to_be_parsed += c
|
||||
stop_after_colon -= 1
|
||||
if stop_after_colon == 0:
|
||||
return (delta_to_be_parsed, delta_text[i + 1 :])
|
||||
elif c == ",":
|
||||
delta_to_be_parsed += c
|
||||
stop_after_comma -= 1
|
||||
if stop_after_comma == 0:
|
||||
return (delta_to_be_parsed, delta_text[i + 1 :])
|
||||
else:
|
||||
delta_to_be_parsed += c
|
||||
|
||||
return (delta_to_be_parsed, "")
|
||||
366
vllm/tool_parsers/olmo3_tool_parser.py
Normal file
366
vllm/tool_parsers/olmo3_tool_parser.py
Normal file
@@ -0,0 +1,366 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _UnexpectedAstError(Exception):
    """Raised when model output parses to an AST that is not a
    well-formed list of keyword-argument function calls."""

    pass
|
||||
|
||||
|
||||
class Olmo3PythonicToolParser(ToolParser):
    """
    Tool call parser for Olmo 3 models that produce tool calls as
    newline-separated pythonic strings.
    Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set
    Code copied from pythonic_tool_parser.py and updated to handle
    - newline separated pythonic tool calls.
    - argument values being null/true/false instead of Pythonic literals.
    """

    # TODO(mdepinet): Possible future improvements:
    # 1. Support text + tools separated by either <|python_tag|> or \n\n
    # 2. Support tools outside of a list (or separated by a semicolon).
    #    This depends on item 1 for consistent streaming.
    # Neither of these are necessary for e.g. ToolACE, but both would help make
    # Llama3.2 models more reliable.

    # Matches a bracketed list of `name(kw=value, ...)` calls; used only as
    # a cheap pre-filter before the real ast.parse below.
    TOOL_CALL_REGEX = re.compile(
        r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
        re.DOTALL,
    )

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """Initialize with the model tokenizer (state lives in the base class)."""
        super().__init__(tokenizer)

    # Rename for readability. This is NOT a tool id.
    @property
    def current_tool_index(self) -> int:
        return self.current_tool_id

    @current_tool_index.setter
    def current_tool_index(self, value: int) -> None:
        self.current_tool_id = value

    def extract_tool_calls(
        self, model_output: str, request: ChatCompletionRequest
    ) -> ExtractedToolCallInformation:
        """
        Extract the tool calls from a complete model response.

        Strips the <function_calls> XML wrapper, joins the
        newline-separated calls into a single pythonic list literal, and
        parses it with ast. On any failure the raw output is returned as
        plain content.
        """
        original_model_output = model_output
        # Remove xml tags.
        match = re.search(
            r"<function_calls>(.*?)</function_calls>", model_output, re.DOTALL
        )
        if match:
            model_output = match.group(1).strip()
        # Make the newline separated function calls into a list.
        model_output = ", ".join(
            [line.strip() for line in model_output.splitlines() if line.strip()]
        )
        model_output = f"[{model_output}]"

        is_tool_call_pattern = False
        try:
            # Cheap regex pre-filter (with timeout) before ast parsing.
            is_tool_call_pattern = (
                self.TOOL_CALL_REGEX.match(
                    model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
                )
                is not None
            )
        except TimeoutError:
            logger.warning("Regex timeout occurred when matching tool call pattern.")
            logger.debug(
                "Regex timeout occurred when matching user input: %s", model_output
            )

        if not is_tool_call_pattern:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=original_model_output
            )

        try:
            module = ast.parse(model_output)
            parsed = getattr(module.body[0], "value", None)
            if isinstance(parsed, ast.List) and all(
                isinstance(e, ast.Call) for e in parsed.elts
            ):
                return ExtractedToolCallInformation(
                    tools_called=True,
                    tool_calls=[
                        _handle_single_tool(e)  # type: ignore
                        for e in parsed.elts
                    ],
                    content=None,
                )
            else:
                raise _UnexpectedAstError(
                    "Tool output must be a list of function calls"
                )
        except Exception:
            logger.exception("Error in extracting tool call from response.")
            # Treat as regular text
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=original_model_output
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Stream tool-call deltas by re-parsing the full text each step.

        The partial output is auto-completed into valid Python via
        _make_valid_python, parsed, and diffed against what has already
        been streamed for each call (_compute_tool_delta).
        """
        # All function calls start with the <function_calls> tag.
        # But since this is streaming, we may have seen only part of the tag.
        if not current_text.startswith("<"):
            return DeltaMessage(content=delta_text)

        try:
            # Remove xml tags.
            if current_text.startswith("<function_calls>"):
                current_text = current_text[len("<function_calls>") :]
            if current_text.endswith("</function_calls>"):
                current_text = current_text[: -len("</function_calls>")]

            valid_and_added_text = _make_valid_python(current_text)
            if valid_and_added_text is None:
                # Output is not completable yet; wait for more tokens.
                return None
            valid_text, added_text = valid_and_added_text

            # Make the newline separated function calls into a list.
            valid_text = ", ".join(
                [line.strip() for line in valid_text.splitlines() if line.strip()]
            )
            valid_text = f"[{valid_text}]"
            module = ast.parse(valid_text)
            parsed = getattr(module.body[0], "value", None)
            if not isinstance(parsed, ast.List) or not all(
                isinstance(e, ast.Call) for e in parsed.elts
            ):
                raise _UnexpectedAstError(
                    "Tool output must be a sequence of newline-separated calls"
                )
            tool_calls = [
                _handle_single_tool(e)  # type: ignore
                for e in parsed.elts
            ]

            tool_deltas = []
            for index, new_call in enumerate(tool_calls):
                if index < self.current_tool_index:
                    # Already fully streamed this call.
                    continue

                self.current_tool_index = index
                if len(self.streamed_args_for_tool) == index:
                    self.streamed_args_for_tool.append("")

                # A call is complete when it isn't the last one, or when no
                # closing ")" had to be synthesized for it.
                new_call_complete = index < len(tool_calls) - 1 or ")" not in added_text
                if new_call_complete:
                    self.current_tool_index += 1

                withheld_suffix = added_text[:-1] if not new_call_complete else ""
                if not new_call_complete and added_text[-1] == ")":
                    # Function call is incomplete. Withhold the closing bracket.
                    withheld_suffix = withheld_suffix + "}"
                # Strings get single quotes in the model-produced string.
                # JSON requires double quotes.
                withheld_suffix = withheld_suffix.replace("'", '"')
                delta = _compute_tool_delta(
                    self.streamed_args_for_tool[index], new_call, index, withheld_suffix
                )

                if delta is not None:
                    tool_deltas.append(delta)
                    if (
                        delta.function is not None
                        and delta.function.arguments is not None
                    ):
                        self.streamed_args_for_tool[index] += delta.function.arguments

            # HACK: serving_chat.py inspects the internal state of tool parsers
            # when determining its final streaming delta, automatically
            # adding autocompleted JSON.
            # These two lines avoid that nonsense while ensuring finish_reason
            # is set to tool_calls when at least one tool is called.
            if tool_deltas and not self.prev_tool_call_arr:
                self.prev_tool_call_arr = [{"arguments": {}}]

            if tool_deltas:
                return DeltaMessage(tool_calls=tool_deltas)
            elif not added_text and self.current_tool_id > 0:
                # Return an empty DeltaMessage once the tool calls are all done
                # so that finish_reason gets set.
                return DeltaMessage(content="")
            else:
                return None
        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction error"
            )
            return None
|
||||
|
||||
|
||||
def _get_parameter_value(val: ast.expr) -> Any:
|
||||
if isinstance(val, ast.Constant):
|
||||
return val.value
|
||||
elif isinstance(val, ast.Dict):
|
||||
if not all(isinstance(k, ast.Constant) for k in val.keys):
|
||||
raise _UnexpectedAstError("Dict tool call arguments must have literal keys")
|
||||
return {
|
||||
k.value: _get_parameter_value(v) # type: ignore
|
||||
for k, v in zip(val.keys, val.values)
|
||||
}
|
||||
elif isinstance(val, ast.List):
|
||||
return [_get_parameter_value(v) for v in val.elts]
|
||||
# The model may return function calls where the values are null/true/false
|
||||
# because the system prompt has API description in json.
|
||||
elif isinstance(val, ast.Name) and val.id in ["null", "true", "false"]:
|
||||
if val.id == "null":
|
||||
return None
|
||||
elif val.id == "true":
|
||||
return True
|
||||
elif val.id == "false":
|
||||
return False
|
||||
else:
|
||||
raise _UnexpectedAstError("Tool call arguments must be literals")
|
||||
|
||||
|
||||
def _handle_single_tool(call: ast.Call) -> ToolCall:
    """Convert one pythonic ``ast.Call`` into a ToolCall with JSON arguments."""
    func = call.func
    if not isinstance(func, ast.Name):
        raise _UnexpectedAstError("Invalid tool call name")
    arguments = {kw.arg: _get_parameter_value(kw.value) for kw in call.keywords}
    serialized_args = json.dumps(arguments, ensure_ascii=False)
    return ToolCall(
        type="function",
        function=FunctionCall(name=func.id, arguments=serialized_args),
    )
|
||||
|
||||
|
||||
def _make_valid_python(text: str) -> tuple[str, str] | None:
    """Auto-complete a partial pythonic tool-call string into valid Python.

    Scans ``text`` tracking open brackets and string quotes, then appends
    the closers needed to balance it. Returns ``(completed_text,
    added_text)``, or None when the prefix is not completable yet (e.g. a
    dangling ``=``/``:`` or a half-typed name). Raises _UnexpectedAstError
    on mismatched brackets.
    """
    bracket_stack = []
    for index, char in enumerate(text):
        if char in {"[", "(", "{"}:
            bracket_stack.append(char)
        elif char == "]":
            if not bracket_stack or bracket_stack.pop() != "[":
                raise _UnexpectedAstError("Mismatched square brackets")
        elif char == ")":
            if not bracket_stack or bracket_stack.pop() != "(":
                raise _UnexpectedAstError("Mismatched parentheses")
        elif char == "}":
            if not bracket_stack or bracket_stack.pop() != "{":
                raise _UnexpectedAstError("Mismatched curly braces")
        elif char in {"'", '"'}:
            if bracket_stack and bracket_stack[-1] == char:
                if index > 0 and text[index - 1] == "\\":
                    # Treat an escaped quote as a regular character
                    pass
                else:
                    # Closing quote of the current string.
                    bracket_stack.pop()
            elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
                # Double quote within a single quote string or vice versa.
                pass
            else:
                # Opening quote: treated like an open bracket.
                bracket_stack.append(char)

    text = text.rstrip()
    if text.endswith("=") or text.endswith(":"):
        # Since we have no type information for this property/parameter value,
        # we can't fill in a valid value.
        return None
    if bracket_stack and bracket_stack[-1] == "{":
        trailing_dict_text = text[: text.rfind("{")]
        # Heuristic: a dict with as many commas as colons is mid-key.
        num_keys = trailing_dict_text.count(":")
        num_values = trailing_dict_text.count(",")
        if num_keys <= num_values:
            return None  # Incomplete property name within parameter value
    if bracket_stack and bracket_stack[-1] == "(":
        trailing_params_text = text[: text.rfind("(")]
        # Heuristic: as many commas as '='s means a parameter name is mid-typing.
        num_full_param_names = trailing_params_text.count("=")
        num_full_param_values = trailing_params_text.count(",")
        if num_full_param_names <= num_full_param_values:
            return None  # Incomplete parameter name
    if text.endswith(","):
        # Drop a trailing comma so the completed text parses.
        text = text[:-1]
    if (
        bracket_stack
        and bracket_stack[-1] == "["
        and not text.endswith("[")
        and not text.endswith(")")
    ):
        return None  # Incomplete function name

    # Close every still-open bracket/quote in reverse (innermost-first) order.
    added_text = ""
    for char in reversed(bracket_stack):
        if char == "[":
            added_text += "]"
        elif char == "(":
            added_text += ")"
        elif char == "{":
            added_text += "}"
        elif char == "'":
            added_text += "'"
        elif char == '"':
            added_text += '"'

    return text + added_text, added_text
|
||||
|
||||
|
||||
def _compute_tool_delta(
    previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str
) -> DeltaToolCall | None:
    """Diff a freshly parsed call against what was already streamed.

    Returns the next DeltaToolCall to emit for this call, or None when
    nothing new is available. ``withheld_suffix`` is the auto-completed
    tail that should not be sent until the model actually produces it.
    """
    full_args = new_call.function.arguments
    if withheld_suffix:
        # Hold back the synthesized closers appended during auto-completion.
        assert full_args.endswith(withheld_suffix)
        full_args = full_args[: -len(withheld_suffix)]

    if not previously_sent_args:
        # First chunk for this call: include id, name, and the args so far.
        opening = DeltaFunctionCall(
            name=new_call.function.name,
            arguments=full_args,
        )
        return DeltaToolCall(
            id=new_call.id,
            type="function",
            index=index,
            function=opening,
        )

    remaining = full_args[len(previously_sent_args) :]
    if not remaining:
        return None
    return DeltaToolCall(
        id=None, index=index, function=DeltaFunctionCall(arguments=remaining)
    )
|
||||
102
vllm/tool_parsers/openai_tool_parser.py
Normal file
102
vllm/tool_parsers/openai_tool_parser.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.entrypoints.openai.parser.harmony_utils import parse_output_into_messages
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
else:
|
||||
TokenizerLike = object
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIToolParser(ToolParser):
    """Tool parser for gpt-oss style (Harmony) outputs.

    Unlike text-based parsers, extraction here requires the raw token ids:
    they are parsed into Harmony messages and tool calls are read from
    messages addressed to ``functions.<name>`` recipients.
    """

    def __init__(self, tokenizer: "TokenizerLike"):
        super().__init__(tokenizer)

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
        token_ids: Sequence[int] | None = None,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from a complete response via its token ids.

        Raises NotImplementedError when ``token_ids`` is missing, since
        this parser cannot work from the detokenized text alone.
        """
        if token_ids is None:
            raise NotImplementedError(
                "OpenAIToolParser requires token IDs and does not support text-based extraction."  # noqa: E501
            )

        parser = parse_output_into_messages(token_ids)
        tool_calls = []
        final_content = None
        commentary_content = None

        if len(parser.messages) > 0:
            for msg in parser.messages:
                if len(msg.content) < 1:
                    # Skip messages with no content parts.
                    continue
                msg_text = msg.content[0].text
                if msg.recipient and msg.recipient.startswith("functions."):
                    # If no content-type is given assume JSON, as that's the
                    # most common case with gpt-oss models.
                    if not msg.content_type or "json" in msg.content_type:
                        # load and dump the JSON text to check validity and
                        # remove any extra newlines or other odd formatting
                        try:
                            tool_args = json.dumps(json.loads(msg_text))
                        except json.JSONDecodeError:
                            logger.exception(
                                "Error decoding JSON tool call from response."
                            )
                            # Fall back to forwarding the raw text.
                            tool_args = msg_text
                    else:
                        tool_args = msg_text
                    tool_calls.append(
                        ToolCall(
                            type="function",
                            function=FunctionCall(
                                name=msg.recipient.split("functions.")[1],
                                arguments=tool_args,
                            ),
                        )
                    )
                elif msg.channel == "final":
                    final_content = msg_text
                elif msg.channel == "commentary" and not msg.recipient:
                    commentary_content = msg_text

        return ExtractedToolCallInformation(
            tools_called=len(tool_calls) > 0,
            tool_calls=tool_calls,
            # prefer final content over commentary content if both are present
            # commentary content is tool call preambles meant to be shown to the user
            content=final_content or commentary_content,
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        # Streaming is handled directly by the serving layer for this
        # format, so this entry point is intentionally unimplemented.
        raise NotImplementedError(
            "Not being used, manual parsing in serving_chat.py"  # noqa: E501
        )
|
||||
120
vllm/tool_parsers/phi4mini_tool_parser.py
Normal file
120
vllm/tool_parsers/phi4mini_tool_parser.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Phi4MiniJsonToolParser(ToolParser):
    """
    Tool call parser for phi-4-mini models intended for use with the
    examples/tool_chat_template_llama.jinja template.

    Used when --enable-auto-tool-choice --tool-call-parser phi4_mini_json
    are all set
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None:
        super().__init__(tokenizer)

        # initialize properties used for state when parsing tool calls in
        # streaming mode
        self.prev_tool_call_arr: list[dict[str, Any]] = []
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        # map what has been streamed for each tool so far to a list
        self.streamed_args_for_tool: list[str] = []
        self.bot_token: str = "functools"

    def extract_tool_calls(
        self, model_output: str, request: ChatCompletionRequest
    ) -> ExtractedToolCallInformation:
        """
        Extract the tool calls from a complete model response.

        The model emits calls as ``functools[{...}, ...]``; output that does
        not match that shape (or cannot be decoded) is returned as plain
        content with no tool calls.
        """
        logger.debug("Model output: %s", model_output)

        pattern = r"functools\[(.*?)\]"
        matches = re.search(pattern, model_output, re.DOTALL)

        if not matches:
            logger.debug("No function calls found")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        try:
            function_call_arr: list[dict[str, Any]] = []
            try:
                json_content = "[" + matches.group(1) + "]"

                function_call_arr = json.loads(json_content)
                logger.debug(
                    "Successfully extracted %d function calls", len(function_call_arr)
                )
            except json.JSONDecodeError as e:
                logger.error(
                    "Failed to parse function calls from model output. Error: %s",
                    str(e),
                )
                # BUG FIX: execution previously fell through here with an
                # empty function_call_arr and returned tools_called=True
                # with no tool calls and content=None, silently dropping
                # the model output. Report it as plain content instead.
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=model_output
                )

            tool_calls: list[ToolCall] = [
                ToolCall(
                    id=make_tool_call_id(),
                    type="function",
                    function=FunctionCall(
                        name=raw_function_call["name"],
                        # function call args are JSON but as a string
                        arguments=json.dumps(
                            raw_function_call["arguments"]
                            if "arguments" in raw_function_call
                            else raw_function_call["parameters"],
                            ensure_ascii=False,
                        ),
                    ),
                )
                for raw_function_call in function_call_arr
            ]

            # get any content before the tool call
            ret = ExtractedToolCallInformation(
                tools_called=True, tool_calls=tool_calls, content=None
            )
            return ret

        except Exception:
            # Malformed call objects (e.g. a call missing both "arguments"
            # and "parameters") degrade to returning the raw output.
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Streaming extraction is not implemented for this parser."""
        return None
|
||||
332
vllm/tool_parsers/pythonic_tool_parser.py
Normal file
332
vllm/tool_parsers/pythonic_tool_parser.py
Normal file
@@ -0,0 +1,332 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _UnexpectedAstError(Exception):
    """Raised when parsed model output does not have the expected AST shape."""
|
||||
|
||||
|
||||
class PythonicToolParser(ToolParser):
    """
    Tool call parser for models that produce tool calls in a pythonic style,
    such as Llama 3.2 and Llama 4 models.

    Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set
    """

    # TODO(mdepinet): Possible future improvements:
    # 1. Support text + tools separated by either <|python_tag|> or \n\n
    # 2. Support tools outside of a list (or separated by a semicolon).
    # This depends on item 1 for consistent streaming.
    # Neither of these are necessary for e.g. ToolACE, but both would help make
    # Llama3.2 models more reliable.

    # Matches a complete bracketed list of keyword-argument calls, e.g.
    # [foo(a=1), bar(b="x")].  Used only as a cheap pre-filter before ast.parse.
    TOOL_CALL_REGEX = re.compile(
        r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
        re.DOTALL,
    )

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)

    # Rename for readability. This is NOT a tool id.
    # NOTE(review): current_tool_id / streamed_args_for_tool / prev_tool_call_arr
    # appear to be initialized by the ToolParser base class — confirm there.
    @property
    def current_tool_index(self) -> int:
        return self.current_tool_id

    @current_tool_index.setter
    def current_tool_index(self, value: int) -> None:
        self.current_tool_id = value

    def extract_tool_calls(
        self, model_output: str, request: ChatCompletionRequest
    ) -> ExtractedToolCallInformation:
        """
        Extract the tool calls from a complete model response.
        """
        # Cheap regex pre-filter; the regex can backtrack badly on adversarial
        # input, so it runs with the `regex` module's timeout support.
        is_tool_call_pattern = False
        try:
            is_tool_call_pattern = (
                self.TOOL_CALL_REGEX.match(
                    model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
                )
                is not None
            )
        except TimeoutError:
            logger.warning("Regex timeout occurred when matching tool call pattern.")
            logger.debug(
                "Regex timeout occurred when matching user input: %s", model_output
            )

        if not is_tool_call_pattern:
            # Not shaped like a pythonic tool-call list: pass through as text.
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        try:
            # Authoritative parse: the output must be a single list literal
            # whose elements are all function calls.
            module = ast.parse(model_output)
            parsed = getattr(module.body[0], "value", None)
            if isinstance(parsed, ast.List) and all(
                isinstance(e, ast.Call) for e in parsed.elts
            ):
                return ExtractedToolCallInformation(
                    tools_called=True,
                    tool_calls=[
                        _handle_single_tool(e)  # type: ignore
                        for e in parsed.elts
                    ],
                    content=None,
                )
            else:
                raise _UnexpectedAstError(
                    "Tool output must be a list of function calls"
                )
        except Exception:
            logger.exception("Error in extracting tool call from response.")
            # Treat as regular text
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Incrementally emit tool-call deltas while the model is generating.

        Strategy: on every chunk, auto-complete the partial text into valid
        Python (_make_valid_python), parse it, and diff each call's arguments
        against what has already been streamed (_compute_tool_delta).
        """
        if not current_text.startswith("["):
            # Plain text response — stream it through unchanged.
            return DeltaMessage(content=delta_text)

        try:
            valid_and_added_text = _make_valid_python(current_text)
            if valid_and_added_text is None:
                # Mid-identifier/mid-value: wait for more tokens.
                return None
            valid_text, added_text = valid_and_added_text

            module = ast.parse(valid_text)
            parsed = getattr(module.body[0], "value", None)
            if not isinstance(parsed, ast.List) or not all(
                isinstance(e, ast.Call) for e in parsed.elts
            ):
                raise _UnexpectedAstError(
                    "Tool output must be a list of function calls"
                )
            tool_calls = [
                _handle_single_tool(e)  # type: ignore
                for e in parsed.elts
            ]

            tool_deltas = []
            for index, new_call in enumerate(tool_calls):
                # Calls before current_tool_index were fully streamed already.
                if index < self.current_tool_index:
                    continue

                self.current_tool_index = index
                if len(self.streamed_args_for_tool) == index:
                    self.streamed_args_for_tool.append("")

                # A call is "complete" when it is not the last one, or when
                # its closing ")]" came from the model rather than from
                # _make_valid_python's auto-completion.
                new_call_complete = (
                    index < len(tool_calls) - 1 or ")]" not in added_text
                )
                if new_call_complete:
                    self.current_tool_index += 1

                # For an incomplete call, withhold the auto-added closers so
                # we never stream text the model has not actually produced.
                withheld_suffix = added_text[:-2] if not new_call_complete else ""
                if not new_call_complete and added_text[-2] == ")":
                    # Function call is incomplete. Withhold the closing bracket.
                    withheld_suffix = withheld_suffix + "}"
                # Strings get single quotes in the model-produced string.
                # JSON requires double quotes.
                withheld_suffix = withheld_suffix.replace("'", '"')
                delta = _compute_tool_delta(
                    self.streamed_args_for_tool[index], new_call, index, withheld_suffix
                )

                if delta is not None:
                    tool_deltas.append(delta)
                    if (
                        delta.function is not None
                        and delta.function.arguments is not None
                    ):
                        # Record what was emitted so the next diff is relative.
                        self.streamed_args_for_tool[index] += delta.function.arguments

            # HACK: serving_chat.py inspects the internal state of tool parsers
            # when determining its final streaming delta, automatically
            # adding autocompleted JSON.
            # These two lines avoid that nonsense while ensuring finish_reason
            # is set to tool_calls when at least one tool is called.
            if tool_deltas and not self.prev_tool_call_arr:
                self.prev_tool_call_arr = [{"arguments": {}}]

            if tool_deltas:
                return DeltaMessage(tool_calls=tool_deltas)
            elif not added_text and self.current_tool_id > 0:
                # Return an empty DeltaMessage once the tool calls are all done
                # so that finish_reason gets set.
                return DeltaMessage(content="")
            else:
                return None
        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction error"
            )
            return None
|
||||
|
||||
|
||||
def _get_parameter_value(val: ast.expr) -> Any:
    """Convert a literal AST expression into the equivalent Python value.

    Supports constants, lists, and dicts with literal keys (recursively);
    any other node raises _UnexpectedAstError.
    """
    if isinstance(val, ast.Constant):
        return val.value
    if isinstance(val, ast.List):
        return [_get_parameter_value(item) for item in val.elts]
    if isinstance(val, ast.Dict):
        if any(not isinstance(key, ast.Constant) for key in val.keys):
            raise _UnexpectedAstError("Dict tool call arguments must have literal keys")
        return {
            key.value: _get_parameter_value(value)  # type: ignore
            for key, value in zip(val.keys, val.values)
        }
    raise _UnexpectedAstError("Tool call arguments must be literals")
|
||||
|
||||
|
||||
def _handle_single_tool(call: ast.Call) -> ToolCall:
    """Convert one parsed ``ast.Call`` into a ToolCall.

    The call must use a plain name (``foo(...)``); its keyword arguments are
    converted to Python literals and serialized as a JSON object string.
    """
    func = call.func
    if not isinstance(func, ast.Name):
        raise _UnexpectedAstError("Invalid tool call name")
    arguments = {
        keyword.arg: _get_parameter_value(keyword.value)
        for keyword in call.keywords
    }
    return ToolCall(
        type="function",
        function=FunctionCall(
            name=func.id, arguments=json.dumps(arguments, ensure_ascii=False)
        ),
    )
|
||||
|
||||
|
||||
def _make_valid_python(text: str) -> tuple[str, str] | None:
|
||||
bracket_stack = []
|
||||
for index, char in enumerate(text):
|
||||
if char in {"[", "(", "{"}:
|
||||
bracket_stack.append(char)
|
||||
elif char == "]":
|
||||
if not bracket_stack or bracket_stack.pop() != "[":
|
||||
raise _UnexpectedAstError("Mismatched square brackets")
|
||||
elif char == ")":
|
||||
if not bracket_stack or bracket_stack.pop() != "(":
|
||||
raise _UnexpectedAstError("Mismatched parentheses")
|
||||
elif char == "}":
|
||||
if not bracket_stack or bracket_stack.pop() != "{":
|
||||
raise _UnexpectedAstError("Mismatched curly braces")
|
||||
elif char in {"'", '"'}:
|
||||
if bracket_stack and bracket_stack[-1] == char:
|
||||
if index > 0 and text[index - 1] == "\\":
|
||||
# Treat an escaped quote as a regular character
|
||||
pass
|
||||
else:
|
||||
bracket_stack.pop()
|
||||
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
|
||||
# Double quote within a single quote string or vice versa.
|
||||
pass
|
||||
else:
|
||||
bracket_stack.append(char)
|
||||
|
||||
text = text.rstrip()
|
||||
if text.endswith("=") or text.endswith(":"):
|
||||
# Since we have no type information for this property/parameter value,
|
||||
# we can't fill in a valid value.
|
||||
return None
|
||||
if bracket_stack and bracket_stack[-1] == "{":
|
||||
trailing_dict_text = text[: text.rfind("{")]
|
||||
num_keys = trailing_dict_text.count(":")
|
||||
num_values = trailing_dict_text.count(",")
|
||||
if num_keys <= num_values:
|
||||
return None # Incomplete property name within parameter value
|
||||
if bracket_stack and bracket_stack[-1] == "(":
|
||||
trailing_params_text = text[: text.rfind("(")]
|
||||
num_full_param_names = trailing_params_text.count("=")
|
||||
num_full_param_values = trailing_params_text.count(",")
|
||||
if num_full_param_names <= num_full_param_values:
|
||||
return None # Incomplete parameter name
|
||||
if text.endswith(","):
|
||||
text = text[:-1]
|
||||
if (
|
||||
bracket_stack
|
||||
and bracket_stack[-1] == "["
|
||||
and not text.endswith("[")
|
||||
and not text.endswith(")")
|
||||
):
|
||||
return None # Incomplete function name
|
||||
|
||||
added_text = ""
|
||||
for char in reversed(bracket_stack):
|
||||
if char == "[":
|
||||
added_text += "]"
|
||||
elif char == "(":
|
||||
added_text += ")"
|
||||
elif char == "{":
|
||||
added_text += "}"
|
||||
elif char == "'":
|
||||
added_text += "'"
|
||||
elif char == '"':
|
||||
added_text += '"'
|
||||
|
||||
return text + added_text, added_text
|
||||
|
||||
|
||||
def _compute_tool_delta(
    previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str
) -> DeltaToolCall | None:
    """Build the streaming delta for one tool call.

    The first delta for a call carries its id, name, and initial argument
    text; later deltas carry only the not-yet-streamed argument suffix.
    Returns None when there is nothing new to send.
    """
    args = new_call.function.arguments
    if withheld_suffix:
        # The caller withholds a possibly auto-completed tail; never stream it.
        assert args.endswith(withheld_suffix)
        args = args[: -len(withheld_suffix)]

    if not previously_sent_args:
        # First chunk for this call: include the header fields.
        return DeltaToolCall(
            id=new_call.id,
            type="function",
            index=index,
            function=DeltaFunctionCall(name=new_call.function.name, arguments=args),
        )

    remaining = args[len(previously_sent_args) :]
    if not remaining:
        return None
    return DeltaToolCall(
        id=None, index=index, function=DeltaFunctionCall(arguments=remaining)
    )
|
||||
781
vllm/tool_parsers/qwen3coder_tool_parser.py
Normal file
781
vllm/tool_parsers/qwen3coder_tool_parser.py
Normal file
@@ -0,0 +1,781 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ast
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionToolsParam,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Qwen3CoderToolParser(ToolParser):
|
||||
    def __init__(self, tokenizer: TokenizerLike):
        """Initialize streaming state, sentinel strings, and regex patterns.

        Raises:
            ValueError: if no tokenizer was supplied.
            RuntimeError: if the tool-call start/end sentinels are missing
                from the tokenizer vocabulary.
        """
        super().__init__(tokenizer)

        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        # Override base class type - we use string IDs for tool calls
        self.current_tool_id: str | None = None  # type: ignore
        self.streamed_args_for_tool: list[str] = []

        # Sentinel tokens for streaming mode
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        self.tool_call_prefix: str = "<function="
        self.function_end_token: str = "</function>"
        self.parameter_prefix: str = "<parameter="
        self.parameter_end_token: str = "</parameter>"
        self.is_tool_call_started: bool = False
        self.failed_count: int = 0

        # Enhanced streaming state - reset for each new message
        self._reset_streaming_state()

        # Regex patterns. The streaming variants each have a second
        # alternative that also matches an unterminated tag at end-of-text.
        self.tool_call_complete_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>", re.DOTALL
        )
        self.tool_call_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL
        )
        self.tool_call_function_regex = re.compile(
            r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
        )
        self.tool_call_parameter_regex = re.compile(
            r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)",
            re.DOTALL,
        )

        # NOTE(review): model_tokenizer and vocab appear to be provided by
        # the ToolParser base class — confirm there.
        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)

        if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
            raise RuntimeError(
                "Qwen3 XML Tool parser could not locate tool call start/end "
                "tokens in the tokenizer!"
            )

        logger.info(
            "vLLM Successfully import tool parser %s !", self.__class__.__name__
        )
|
||||
|
||||
def _generate_tool_call_id(self) -> str:
|
||||
"""Generate a unique tool call ID."""
|
||||
return f"call_{uuid.uuid4().hex[:24]}"
|
||||
|
||||
def _reset_streaming_state(self):
|
||||
"""Reset all streaming state."""
|
||||
self.current_tool_index = 0
|
||||
self.is_tool_call_started = False
|
||||
self.header_sent = False
|
||||
self.current_tool_id = None
|
||||
self.current_function_name = None
|
||||
self.current_param_name = None
|
||||
self.current_param_value = ""
|
||||
self.param_count = 0
|
||||
self.in_param = False
|
||||
self.in_function = False
|
||||
self.accumulated_text = ""
|
||||
self.json_started = False
|
||||
self.json_closed = False
|
||||
# Store accumulated parameters for type conversion
|
||||
self.accumulated_params = {}
|
||||
self.streaming_request = None
|
||||
|
||||
def _get_arguments_config(
|
||||
self, func_name: str, tools: list[ChatCompletionToolsParam] | None
|
||||
) -> dict:
|
||||
"""Extract argument configuration for a function."""
|
||||
if tools is None:
|
||||
return {}
|
||||
for config in tools:
|
||||
if not hasattr(config, "type") or not (
|
||||
hasattr(config, "function") and hasattr(config.function, "name")
|
||||
):
|
||||
continue
|
||||
if config.type == "function" and config.function.name == func_name:
|
||||
if not hasattr(config.function, "parameters"):
|
||||
return {}
|
||||
params = config.function.parameters
|
||||
if isinstance(params, dict) and "properties" in params:
|
||||
return params["properties"]
|
||||
elif isinstance(params, dict):
|
||||
return params
|
||||
else:
|
||||
return {}
|
||||
logger.debug("Tool '%s' is not defined in the tools list.", func_name)
|
||||
return {}
|
||||
|
||||
    def _convert_param_value(
        self, param_value: str, param_name: str, param_config: dict, func_name: str
    ) -> Any:
        """Convert parameter value based on its type in the schema.

        Any conversion failure degrades gracefully: the raw string is
        returned (or, for booleans, anything but "true" becomes False).
        """
        # Handle null value for any type
        if param_value.lower() == "null":
            return None

        # Unknown parameter: return the raw string unchanged.
        if param_name not in param_config:
            if param_config != {}:
                logger.debug(
                    "Parsed parameter '%s' is not defined in the tool "
                    "parameters for tool '%s', directly returning the "
                    "string value.",
                    param_name,
                    func_name,
                )
            return param_value

        # Resolve the declared type; default to "string" when unspecified.
        if (
            isinstance(param_config[param_name], dict)
            and "type" in param_config[param_name]
        ):
            param_type = str(param_config[param_name]["type"]).strip().lower()
        else:
            param_type = "string"
        if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
            return param_value
        elif (
            param_type.startswith("int")
            or param_type.startswith("uint")
            or param_type.startswith("long")
            or param_type.startswith("short")
            or param_type.startswith("unsigned")
        ):
            try:
                return int(param_value)
            except (ValueError, TypeError):
                logger.debug(
                    "Parsed value '%s' of parameter '%s' is not an "
                    "integer in tool '%s', degenerating to string.",
                    param_value,
                    param_name,
                    func_name,
                )
                return param_value
        elif param_type.startswith("num") or param_type.startswith("float"):
            try:
                float_param_value = float(param_value)
                # Collapse whole-number floats (e.g. "3.0") to ints.
                return (
                    float_param_value
                    if float_param_value - int(float_param_value) != 0
                    else int(float_param_value)
                )
            except (ValueError, TypeError):
                logger.debug(
                    "Parsed value '%s' of parameter '%s' is not a float "
                    "in tool '%s', degenerating to string.",
                    param_value,
                    param_name,
                    func_name,
                )
                return param_value
        elif param_type in ["boolean", "bool", "binary"]:
            param_value = param_value.lower()
            if param_value not in ["true", "false"]:
                logger.debug(
                    "Parsed value '%s' of parameter '%s' is not a boolean "
                    "(`true` or `false`) in tool '%s', degenerating to "
                    "false.",
                    param_value,
                    param_name,
                    func_name,
                )
            return param_value == "true"
        else:
            # Structured or unrecognized types: try JSON first (for declared
            # object/array types), then fall back to a Python literal parse.
            if (
                param_type in ["object", "array", "arr"]
                or param_type.startswith("dict")
                or param_type.startswith("list")
            ):
                try:
                    param_value = json.loads(param_value)
                    return param_value
                except (json.JSONDecodeError, TypeError, ValueError):
                    logger.debug(
                        "Parsed value '%s' of parameter '%s' cannot be "
                        "parsed with json.loads in tool '%s', will try "
                        "other methods to parse it.",
                        param_value,
                        param_name,
                        func_name,
                    )
            try:
                param_value = ast.literal_eval(param_value)  # safer
            except (ValueError, SyntaxError, TypeError):
                logger.debug(
                    "Parsed value '%s' of parameter '%s' cannot be "
                    "converted via Python `ast.literal_eval()` in tool "
                    "'%s', degenerating to string.",
                    param_value,
                    param_name,
                    func_name,
                )
            return param_value
|
||||
|
||||
def _parse_xml_function_call(
|
||||
self, function_call_str: str, tools: list[ChatCompletionToolsParam] | None
|
||||
) -> ToolCall | None:
|
||||
# Extract function name
|
||||
end_index = function_call_str.index(">")
|
||||
function_name = function_call_str[:end_index]
|
||||
param_config = self._get_arguments_config(function_name, tools)
|
||||
parameters = function_call_str[end_index + 1 :]
|
||||
param_dict = {}
|
||||
for match_text in self.tool_call_parameter_regex.findall(parameters):
|
||||
idx = match_text.index(">")
|
||||
param_name = match_text[:idx]
|
||||
param_value = str(match_text[idx + 1 :])
|
||||
# Remove prefix and trailing \n
|
||||
if param_value.startswith("\n"):
|
||||
param_value = param_value[1:]
|
||||
if param_value.endswith("\n"):
|
||||
param_value = param_value[:-1]
|
||||
|
||||
param_dict[param_name] = self._convert_param_value(
|
||||
param_value, param_name, param_config, function_name
|
||||
)
|
||||
return ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=function_name, arguments=json.dumps(param_dict, ensure_ascii=False)
|
||||
),
|
||||
)
|
||||
|
||||
def _get_function_calls(self, model_output: str) -> list[str]:
|
||||
# Find all tool calls
|
||||
matched_ranges = self.tool_call_regex.findall(model_output)
|
||||
raw_tool_calls = [
|
||||
match[0] if match[0] else match[1] for match in matched_ranges
|
||||
]
|
||||
|
||||
# Back-off strategy if no tool_call tags found
|
||||
if len(raw_tool_calls) == 0:
|
||||
raw_tool_calls = [model_output]
|
||||
|
||||
raw_function_calls = []
|
||||
for tool_call in raw_tool_calls:
|
||||
raw_function_calls.extend(self.tool_call_function_regex.findall(tool_call))
|
||||
|
||||
function_calls = [
|
||||
match[0] if match[0] else match[1] for match in raw_function_calls
|
||||
]
|
||||
return function_calls
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
# Quick check to avoid unnecessary processing
|
||||
if self.tool_call_prefix not in model_output:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
try:
|
||||
function_calls = self._get_function_calls(model_output)
|
||||
if len(function_calls) == 0:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
tool_calls = [
|
||||
self._parse_xml_function_call(function_call_str, request.tools)
|
||||
for function_call_str in function_calls
|
||||
]
|
||||
|
||||
# Populate prev_tool_call_arr for serving layer to set finish_reason
|
||||
self.prev_tool_call_arr.clear() # Clear previous calls
|
||||
for tool_call in tool_calls:
|
||||
if tool_call:
|
||||
self.prev_tool_call_arr.append(
|
||||
{
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments,
|
||||
}
|
||||
)
|
||||
|
||||
# Extract content before tool calls
|
||||
content_index = model_output.find(self.tool_call_start_token)
|
||||
idx = model_output.find(self.tool_call_prefix)
|
||||
content_index = content_index if content_index >= 0 else idx
|
||||
content = model_output[:content_index] # .rstrip()
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=(len(tool_calls) > 0),
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None:
    """Incrementally convert streamed XML tool-call text into deltas.

    Invoked once per generation step with the cumulative text so far
    (*current_text*) and the newly generated piece (*delta_text*). A
    per-request state machine is kept on ``self`` (``header_sent``,
    ``json_started``/``json_closed``, ``current_tool_index``, ...) and,
    depending on the state, the method returns:

    * ``DeltaMessage(content=...)`` for plain text outside tool calls,
    * ``DeltaMessage(tool_calls=[...])`` fragments that together form a
      JSON arguments object for each ``<function=...>`` block,
    * ``DeltaMessage(content="")`` as an empty final delta so the serving
      layer can emit ``finish_reason="tool_calls"``,
    * ``None`` when nothing should be emitted for this step.

    NOTE(review): the ``self.in_param`` continuation branch at the bottom
    is dead in the current flow (parameters are emitted whole, above);
    its own comment says as much.
    """
    # Store request for type conversion; a fresh request resets all state.
    if not previous_text:
        self._reset_streaming_state()
        self.streaming_request = request

    # If no delta text, return None unless it's an EOS token after tools
    if not delta_text:
        # Check if this is an EOS token after all tool calls are complete
        # Check for tool calls in text even if is_tool_call_started
        # is False (might have been reset after processing all tools)
        if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids:
            # Count complete tool calls
            complete_calls = len(
                self.tool_call_complete_regex.findall(current_text)
            )

            # If we have completed tool calls and populated
            # prev_tool_call_arr
            if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
                # Check if all tool calls are closed (starts == ends)
                open_calls = current_text.count(
                    self.tool_call_start_token
                ) - current_text.count(self.tool_call_end_token)
                if open_calls == 0:
                    # Return empty delta for finish_reason processing
                    return DeltaMessage(content="")
            elif not self.is_tool_call_started and current_text:
                # This is a regular content response that's now complete
                return DeltaMessage(content="")
        return None

    # Update accumulated text
    self.accumulated_text = current_text

    # Check if we need to advance to next tool
    if self.json_closed and not self.in_function:
        # Check if this tool call has ended
        tool_ends = current_text.count(self.tool_call_end_token)
        if tool_ends > self.current_tool_index:
            # This tool has ended, advance to next and clear per-tool state
            self.current_tool_index += 1
            self.header_sent = False
            self.param_count = 0
            self.json_started = False
            self.json_closed = False
            self.accumulated_params = {}

            # Check if there are more tool calls
            tool_starts = current_text.count(self.tool_call_start_token)
            if self.current_tool_index >= tool_starts:
                # No more tool calls
                self.is_tool_call_started = False
            # Continue processing next tool
            return None

    # Handle normal content before tool calls
    if not self.is_tool_call_started:
        # Check if tool call is starting
        if (
            self.tool_call_start_token_id in delta_token_ids
            or self.tool_call_start_token in delta_text
        ):
            self.is_tool_call_started = True
            # Return any content before the tool call
            if self.tool_call_start_token in delta_text:
                content_before = delta_text[
                    : delta_text.index(self.tool_call_start_token)
                ]
                if content_before:
                    return DeltaMessage(content=content_before)
            return None
        else:
            # Check if we're between tool calls - skip whitespace
            if (
                current_text.rstrip().endswith(self.tool_call_end_token)
                and delta_text.strip() == ""
            ):
                # We just ended a tool call, skip whitespace
                return None
            # Normal content, no tool call
            return DeltaMessage(content=delta_text)

    # Check if we're between tool calls (waiting for next one)
    # Count tool calls we've seen vs processed
    tool_starts_count = current_text.count(self.tool_call_start_token)
    if self.current_tool_index >= tool_starts_count:
        # We're past all tool calls, shouldn't be here
        return None

    # We're in a tool call, find the current tool call portion
    # Need to find the correct tool call based on current_tool_index
    tool_start_positions: list[int] = []
    idx = 0
    while True:
        idx = current_text.find(self.tool_call_start_token, idx)
        if idx == -1:
            break
        tool_start_positions.append(idx)
        idx += len(self.tool_call_start_token)

    if self.current_tool_index >= len(tool_start_positions):
        # No more tool calls to process yet
        return None

    tool_start_idx = tool_start_positions[self.current_tool_index]
    # Find where this tool call ends (or current position if not ended yet)
    tool_end_idx = current_text.find(self.tool_call_end_token, tool_start_idx)
    if tool_end_idx == -1:
        tool_text = current_text[tool_start_idx:]
    else:
        tool_text = current_text[
            tool_start_idx : tool_end_idx + len(self.tool_call_end_token)
        ]

    # Looking for function header ("<function=NAME>")
    if not self.header_sent:
        if self.tool_call_prefix in tool_text:
            func_start = tool_text.find(self.tool_call_prefix) + len(
                self.tool_call_prefix
            )
            func_end = tool_text.find(">", func_start)

            if func_end != -1:
                # Found complete function name
                self.current_function_name = tool_text[func_start:func_end]
                self.current_tool_id = self._generate_tool_call_id()
                self.header_sent = True
                self.in_function = True

                # IMPORTANT: Add to prev_tool_call_arr immediately when
                # we detect a tool call. This ensures
                # finish_reason="tool_calls" even if parsing isn't complete
                already_added = any(
                    tool.get("name") == self.current_function_name
                    for tool in self.prev_tool_call_arr
                )
                if not already_added:
                    self.prev_tool_call_arr.append(
                        {
                            "name": self.current_function_name,
                            "arguments": "{}",  # Placeholder, will be updated later
                        }
                    )

                # Send header with function info
                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_index,
                            id=self.current_tool_id,
                            function=DeltaFunctionCall(
                                name=self.current_function_name, arguments=""
                            ),
                            type="function",
                        )
                    ]
                )
        # Header not complete yet; wait for more text.
        return None

    # We've sent header, now handle function body
    if self.in_function:
        # Send opening brace if not sent yet
        if not self.json_started and self.parameter_prefix not in delta_text:
            self.json_started = True
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_index,
                        function=DeltaFunctionCall(arguments="{"),
                    )
                ]
            )

        # Make sure json_started is set if we're processing parameters
        if not self.json_started:
            self.json_started = True

        # Check for function end in accumulated text
        if not self.json_closed and self.function_end_token in tool_text:
            # Close JSON
            self.json_closed = True

            # Extract complete tool call to update
            # prev_tool_call_arr with final arguments
            # Find the function content
            func_start = tool_text.find(self.tool_call_prefix) + len(
                self.tool_call_prefix
            )
            func_content_end = tool_text.find(self.function_end_token, func_start)
            if func_content_end != -1:
                func_content = tool_text[func_start:func_content_end]
                # Parse to get the complete arguments
                try:
                    parsed_tool = self._parse_xml_function_call(
                        func_content,
                        self.streaming_request.tools
                        if self.streaming_request
                        else None,
                    )
                    if parsed_tool:
                        # Update existing entry in
                        # prev_tool_call_arr with complete args
                        for i, tool in enumerate(self.prev_tool_call_arr):
                            if tool.get("name") == parsed_tool.function.name:
                                args = parsed_tool.function.arguments
                                self.prev_tool_call_arr[i]["arguments"] = args
                                break
                except Exception:
                    pass  # Ignore parsing errors during streaming

            result = DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_index,
                        function=DeltaFunctionCall(arguments="}"),
                    )
                ]
            )

            # Reset state for next tool
            self.in_function = False
            self.json_closed = True
            self.accumulated_params = {}

            return result

        # Look for parameters
        # Find all parameter starts
        param_starts: list[int] = []
        idx = 0
        while True:
            idx = tool_text.find(self.parameter_prefix, idx)
            if idx == -1:
                break
            param_starts.append(idx)
            idx += len(self.parameter_prefix)

        # Check if we should start a new parameter
        if (
            not self.in_param
            and self.param_count < len(param_starts)
            and len(param_starts) > self.param_count
        ):
            # Process the next parameter
            param_idx = param_starts[self.param_count]
            param_start = param_idx + len(self.parameter_prefix)
            remaining = tool_text[param_start:]

            if ">" in remaining:
                # We have the complete parameter name
                name_end = remaining.find(">")
                self.current_param_name = remaining[:name_end]

                # Find the parameter value (strip one leading newline)
                value_start = param_start + name_end + 1
                value_text = tool_text[value_start:]
                if value_text.startswith("\n"):
                    value_text = value_text[1:]

                # Find where this parameter ends
                param_end_idx = value_text.find(self.parameter_end_token)
                if param_end_idx == -1:
                    # No closing tag, look for next parameter or
                    # function end
                    next_param_idx = value_text.find(self.parameter_prefix)
                    func_end_idx = value_text.find(self.function_end_token)

                    if next_param_idx != -1 and (
                        func_end_idx == -1 or next_param_idx < func_end_idx
                    ):
                        param_end_idx = next_param_idx
                    elif func_end_idx != -1:
                        param_end_idx = func_end_idx
                    else:
                        # Neither found, check if tool call is complete
                        if self.tool_call_end_token in tool_text:
                            # Tool call is complete, so parameter
                            # must be complete too. Use all
                            # remaining text before function end
                            param_end_idx = len(value_text)
                        else:
                            # Still streaming, wait for more content
                            return None

                if param_end_idx != -1:
                    # Complete parameter found (strip one trailing newline)
                    param_value = value_text[:param_end_idx]
                    if param_value.endswith("\n"):
                        param_value = param_value[:-1]

                    # Store raw value for later processing
                    self.accumulated_params[self.current_param_name] = param_value

                    # Get parameter configuration for type conversion
                    param_config = self._get_arguments_config(
                        self.current_function_name or "",
                        self.streaming_request.tools
                        if self.streaming_request
                        else None,
                    )

                    # Convert param value to appropriate type
                    converted_value = self._convert_param_value(
                        param_value,
                        self.current_param_name,
                        param_config,
                        self.current_function_name or "",
                    )

                    # Build JSON fragment based on the converted type
                    # Use json.dumps to properly serialize the value
                    serialized_value = json.dumps(
                        converted_value, ensure_ascii=False
                    )

                    # First parameter has no leading comma.
                    if self.param_count == 0:
                        json_fragment = (
                            f'"{self.current_param_name}": {serialized_value}'
                        )
                    else:
                        json_fragment = (
                            f', "{self.current_param_name}": {serialized_value}'
                        )

                    self.param_count += 1

                    return DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_index,
                                function=DeltaFunctionCall(arguments=json_fragment),
                            )
                        ]
                    )

        # Continue parameter value - Not used in the current implementation
        # since we process complete parameters above
        if self.in_param:
            if self.parameter_end_token in delta_text:
                # End of parameter
                end_idx = delta_text.find(self.parameter_end_token)
                value_chunk = delta_text[:end_idx]

                # Skip past > if at start
                if not self.current_param_value and ">" in value_chunk:
                    gt_idx = value_chunk.find(">")
                    value_chunk = value_chunk[gt_idx + 1 :]

                if not self.current_param_value and value_chunk.startswith("\n"):
                    value_chunk = value_chunk[1:]

                # Store complete value
                full_value = self.current_param_value + value_chunk
                self.accumulated_params[self.current_param_name] = full_value

                # Get parameter configuration for type conversion
                param_config = self._get_arguments_config(
                    self.current_function_name or "",
                    self.streaming_request.tools
                    if self.streaming_request
                    else None,
                )

                # Convert the parameter value to the appropriate type
                converted_value = self._convert_param_value(
                    full_value,
                    self.current_param_name or "",
                    param_config,
                    self.current_function_name or "",
                )

                # Serialize the converted value
                serialized_value = json.dumps(converted_value, ensure_ascii=False)

                # Since we've been streaming the quoted version,
                # we need to close it properly
                # This is complex - for now just complete the value
                self.in_param = False
                self.current_param_value = ""

                # Just close the current parameter string
                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_index,
                            function=DeltaFunctionCall(
                                arguments='"'
                            ),  # Close the string quote
                        )
                    ]
                )
            else:
                # Continue accumulating value
                value_chunk = delta_text

                # Handle first chunk after param name
                if not self.current_param_value and ">" in value_chunk:
                    gt_idx = value_chunk.find(">")
                    value_chunk = value_chunk[gt_idx + 1 :]

                if not self.current_param_value and value_chunk.startswith("\n"):
                    value_chunk = value_chunk[1:]

                if value_chunk:
                    # Stream the escaped delta: diff the JSON-escaped
                    # before/after strings so only new escaped chars go out
                    prev_escaped = (
                        json.dumps(self.current_param_value, ensure_ascii=False)[
                            1:-1
                        ]
                        if self.current_param_value
                        else ""
                    )
                    self.current_param_value += value_chunk
                    full_escaped = json.dumps(
                        self.current_param_value, ensure_ascii=False
                    )[1:-1]
                    delta_escaped = full_escaped[len(prev_escaped) :]

                    if delta_escaped:
                        return DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_index,
                                    function=DeltaFunctionCall(
                                        arguments=delta_escaped
                                    ),
                                )
                            ]
                        )

    return None
|
||||
1316
vllm/tool_parsers/qwen3xml_tool_parser.py
Normal file
1316
vllm/tool_parsers/qwen3xml_tool_parser.py
Normal file
File diff suppressed because it is too large
Load Diff
744
vllm/tool_parsers/seed_oss_tool_parser.py
Normal file
744
vllm/tool_parsers/seed_oss_tool_parser.py
Normal file
@@ -0,0 +1,744 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Adapted from qwen3coder xml parser, All rights reserved.
|
||||
# ruff: noqa: E501
|
||||
|
||||
import ast
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionToolsParam,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SeedOssToolParser(ToolParser):
|
||||
TOOL_CALL_START = "<seed:tool_call>"
|
||||
TOOL_CALL_END = "</seed:tool_call>"
|
||||
|
||||
def __init__(self, tokenizer: TokenizerLike):
    """Build the Seed-OSS XML tool parser.

    Args:
        tokenizer: Model tokenizer, used to resolve the special
            tool-call / think tokens to vocabulary ids.

    Raises:
        RuntimeError: If the tokenizer vocabulary lacks the
            ``<seed:tool_call>`` start or end token.
    """
    super().__init__(tokenizer)

    # Populated by extract_tool_calls* so the serving layer can set
    # finish_reason="tool_calls".
    self.prev_tool_call_arr: list[dict] = []

    # Outer tool-call delimiters.
    self.tool_call_start_token: str = self.TOOL_CALL_START
    self.tool_call_end_token: str = self.TOOL_CALL_END
    # Sentinel tokens for streaming mode
    self.tool_call_prefix: str = "<function="
    self.function_end_token: str = "</function>"
    self.parameter_prefix: str = "<parameter="
    self.parameter_end_token: str = "</parameter>"
    self.think_start_token: str = "<seed:think>"
    self.think_end_token: str = "</seed:think>"
    self.is_thinking_end: bool = False
    self.failed_count: int = 0
    # BUGFIX: _reset_streaming_state() was previously invoked twice (once
    # at the top of __init__ and once here) with a redundant
    # is_tool_call_started assignment in between; a single call
    # initializes all streaming state, including is_tool_call_started.
    self._reset_streaming_state()

    # Resolve special tokens to ids (None when absent from the vocab).
    self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
    self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
    self.think_end_token_id = self.vocab.get(self.think_end_token)

    if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
        raise RuntimeError(
            "Seed_Oss XML parser: tokenizer did not include "
            "<seed:tool_call> or its closing tag."
        )

    tool_start_re = re.escape(self.tool_call_start_token)
    tool_end_re = re.escape(self.tool_call_end_token)

    # Complete calls only (for counting) vs. complete-or-unterminated
    # trailing call (for extraction while still streaming).
    self.tool_call_complete_regex = re.compile(
        rf"{tool_start_re}(.*?){tool_end_re}", re.DOTALL
    )
    self.tool_call_regex = re.compile(
        rf"{tool_start_re}(.*?){tool_end_re}|{tool_start_re}(.*?)$", re.DOTALL
    )

    self.tool_call_function_regex = re.compile(
        r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
    )
    self.tool_call_parameter_regex = re.compile(
        r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL
    )

    logger.info(
        "vLLM Seed-Oss XML tool parser loaded (%s).", self.__class__.__name__
    )
|
||||
|
||||
def _generate_tool_call_id(self) -> str:
|
||||
"""Generate a unique tool call ID."""
|
||||
return f"call_{uuid.uuid4().hex[:24]}"
|
||||
|
||||
def _reset_streaming_state(self):
|
||||
"""Reset all streaming state."""
|
||||
self.current_tool_index = 0
|
||||
self.is_tool_call_started = False
|
||||
self.header_sent = False
|
||||
self.current_tool_id = -1
|
||||
self.current_function_name = None
|
||||
self.current_param_name = None
|
||||
self.current_param_value = ""
|
||||
self.param_count = 0
|
||||
self.in_param = False
|
||||
self.in_function = False
|
||||
self.accumulated_text = ""
|
||||
self.json_started = False
|
||||
self.json_closed = False
|
||||
|
||||
def _parse_xml_function_call(
    self, function_call_str: str, tools: list[ChatCompletionToolsParam] | None
) -> ToolCall | None:
    """Parse one ``<function=...`` body into a ``ToolCall``.

    *function_call_str* is the text after the ``<function=`` prefix, i.e.
    ``NAME>...parameters...``. Parameter values are coerced using the
    JSON-schema ``type`` found in *tools* (string kept as-is; int/float/
    bool parsed; object via ``json.loads`` then ``ast.literal_eval`` as a
    fallback), degenerating to the raw string with a warning on failure.

    Returns:
        A ``ToolCall`` whose ``arguments`` is the JSON-serialized
        parameter dict.

    Raises:
        ValueError: Propagated from ``str.index`` if ``>`` is missing
            from the header (callers wrap this in a broad try/except).
    """

    def get_arguments_config(func_name: str) -> dict:
        # Locate the JSON-schema "properties" mapping for func_name in
        # the request's tool list; empty dict when unknown/absent.
        if tools is None:
            return {}
        for config in tools:
            if not hasattr(config, "type") or not (
                hasattr(config, "function") and hasattr(config.function, "name")
            ):
                continue
            if config.type == "function" and config.function.name == func_name:
                if not hasattr(config.function, "parameters"):
                    return {}
                params = config.function.parameters
                if isinstance(params, dict) and "properties" in params:
                    return params["properties"]
                elif isinstance(params, dict):
                    return params
                else:
                    return {}
        logger.warning("Tool '%s' is not defined in the tools list.", func_name)
        return {}

    def convert_param_value(
        param_value: str, param_name: str, param_config: dict, func_name: str
    ) -> Any:
        # Coerce the raw string value according to the declared schema
        # type; on any failure fall back to the raw string (with warning).
        # Handle null value for any type
        if param_value.lower() == "null":
            return None

        if param_name not in param_config:
            if param_config != {}:
                logger.warning(
                    "Parsed parameter '%s' is not defined in "
                    "the tool parameters for tool '%s', "
                    "directly returning the string value.",
                    param_name,
                    func_name,
                )
            return param_value

        # Default to "string" when the schema entry has no "type" key.
        if (
            isinstance(param_config[param_name], dict)
            and "type" in param_config[param_name]
        ):
            param_type = str(param_config[param_name]["type"]).strip().lower()
        else:
            param_type = "string"
        if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
            return param_value
        elif (
            param_type.startswith("int")
            or param_type.startswith("uint")
            or param_type.startswith("long")
            or param_type.startswith("short")
            or param_type.startswith("unsigned")
        ):
            try:
                param_value = int(param_value)  # type: ignore
            except (ValueError, TypeError):
                logger.warning(
                    "Parsed value '%s' of parameter '%s' is not an integer in tool "
                    "'%s', degenerating to string.",
                    param_value,
                    param_name,
                    func_name,
                )
            return param_value
        elif param_type.startswith("num") or param_type.startswith("float"):
            try:
                float_param_value = float(param_value)
                # Whole-valued floats are narrowed to int.
                param_value = (
                    float_param_value  # type: ignore
                    if float_param_value - int(float_param_value) != 0
                    else int(float_param_value)  # type: ignore
                )
            except (ValueError, TypeError):
                logger.warning(
                    "Parsed value '%s' of parameter '%s' is not a float in tool "
                    "'%s', degenerating to string.",
                    param_value,
                    param_name,
                    func_name,
                )
            return param_value
        elif param_type in ["boolean", "bool", "binary"]:
            param_value = param_value.lower()
            if param_value not in ["true", "false"]:
                logger.warning(
                    "Parsed value '%s' of parameter '%s' is not a boolean "
                    "(`true` of `false`) in tool '%s', degenerating to false.",
                    param_value,
                    param_name,
                    func_name,
                )
            return param_value == "true"
        else:
            # Objects/dicts: try strict JSON first, then Python literal.
            if param_type == "object" or param_type.startswith("dict"):
                try:
                    param_value = json.loads(param_value)
                    return param_value
                except (ValueError, TypeError, json.JSONDecodeError):
                    logger.warning(
                        "Parsed value '%s' of parameter '%s' is not a valid JSON "
                        "object in tool '%s', will try other methods to parse it.",
                        param_value,
                        param_name,
                        func_name,
                    )
            try:
                param_value = ast.literal_eval(param_value)
            except (ValueError, SyntaxError):
                logger.warning(
                    "Parsed value '%s' of parameter '%s' cannot be converted via "
                    "Python `ast.literal_eval()` in tool '%s', degenerating to string.",
                    param_value,
                    param_name,
                    func_name,
                )
            return param_value

    # Extract function name (text before the first ">").
    end_index = function_call_str.index(">")
    function_name = function_call_str[:end_index]
    param_config = get_arguments_config(function_name)
    parameters = function_call_str[end_index + 1 :]
    param_dict = {}
    # Each regex match yields a 2-tuple of alternative groups: closed
    # "<parameter=...</parameter>" or an unterminated trailing one.
    for match in self.tool_call_parameter_regex.findall(parameters):
        match_text = match[0] if match[0] else match[1]
        idx = match_text.index(">")
        param_name = match_text[:idx]
        param_value = str(match_text[idx + 1 :])
        # Remove prefix and trailing \n (one newline each side at most)
        if param_value.startswith("\n"):
            param_value = param_value[1:]
        if param_value.endswith("\n"):
            param_value = param_value[:-1]

        param_dict[param_name] = convert_param_value(
            param_value, param_name, param_config, function_name
        )
    return ToolCall(
        type="function",
        function=FunctionCall(
            name=function_name, arguments=json.dumps(param_dict, ensure_ascii=False)
        ),
    )
|
||||
|
||||
def _get_function_calls(self, model_output: str) -> list[str]:
|
||||
# Find all tool calls
|
||||
matched_ranges = self.tool_call_regex.findall(model_output)
|
||||
raw_tool_calls = [
|
||||
match[0] if match[0] else match[1] for match in matched_ranges
|
||||
]
|
||||
|
||||
# Back-off strategy if no tool_call tags found
|
||||
if len(raw_tool_calls) == 0:
|
||||
raw_tool_calls = [model_output]
|
||||
|
||||
raw_function_calls = []
|
||||
for tool_call in raw_tool_calls:
|
||||
raw_function_calls.extend(self.tool_call_function_regex.findall(tool_call))
|
||||
|
||||
function_calls = [
|
||||
match[0] if match[0] else match[1] for match in raw_function_calls
|
||||
]
|
||||
return function_calls
|
||||
|
||||
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """Extract tool calls from a complete (non-streaming) Seed-OSS response.

    A closed ``<seed:think>...</seed:think>`` section, when present, is
    split off first and prepended to the returned content; tool calls are
    parsed only from the text after it.

    Args:
        model_output: Full text generated by the model.
        request: Originating chat-completion request; ``request.tools``
            supplies the parameter schemas used for value type conversion.

    Returns:
        ``ExtractedToolCallInformation`` with the parsed tool calls and
        any content preceding them (``None`` when empty). On any parsing
        error the whole output is returned as content with no tool calls.
    """
    # Quick check to avoid unnecessary processing
    if self.tool_call_prefix not in model_output:
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    # Check if both think start and end tokens are present
    if (
        self.think_start_token in model_output
        and self.think_end_token in model_output
    ):
        # Split at the end of the closed thinking section.
        think_end_index = model_output.find(self.think_end_token) + len(
            self.think_end_token
        )
        result_content = model_output[think_end_index:]
        thinking_content = model_output[:think_end_index]
    else:
        thinking_content = ""
        result_content = model_output

    try:
        function_calls = self._get_function_calls(result_content)
        if len(function_calls) == 0:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        # BUGFIX: _parse_xml_function_call returns None for blocks it
        # cannot parse; previously those None entries leaked into
        # tool_calls (and inflated tools_called). Keep only real calls.
        tool_calls = [
            parsed
            for parsed in (
                self._parse_xml_function_call(function_call_str, request.tools)
                for function_call_str in function_calls
            )
            if parsed is not None
        ]

        # Populate prev_tool_call_arr for serving layer to set finish_reason
        self.prev_tool_call_arr.clear()  # Clear previous calls
        for tool_call in tool_calls:
            self.prev_tool_call_arr.append(
                {
                    "name": tool_call.function.name,
                    "arguments": tool_call.function.arguments,
                }
            )

        # Extract content before tool calls. Prefer the explicit start
        # token; fall back to the bare "<function=" prefix (guaranteed to
        # be present by the quick check above).
        tool_call_start_index = result_content.find(self.tool_call_start_token)
        tool_call_start_index = (
            tool_call_start_index
            if tool_call_start_index >= 0
            else result_content.find(self.tool_call_prefix)
        )
        content = thinking_content + result_content[:tool_call_start_index]

        return ExtractedToolCallInformation(
            tools_called=(len(tool_calls) > 0),
            tool_calls=tool_calls,
            content=content if content else None,
        )

    except Exception:
        logger.exception("Error in extracting tool call from response.")
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )
|
||||
|
||||
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None:
    """Incrementally extract XML-style tool calls from a streamed response.

    Called once per decoded delta. Maintains a per-request state machine on
    the instance (``is_thinking_end``, ``is_tool_call_started``,
    ``header_sent``, ``json_started``/``json_closed``, ``param_count``,
    ``current_tool_index``) and translates the model's
    ``<tool_call>…<function=name>…<parameter=key>value</parameter>…`` XML
    into OpenAI-style ``DeltaToolCall`` fragments whose ``arguments`` field
    is emitted as incrementally valid JSON text.

    Returns a ``DeltaMessage`` with either plain ``content`` or
    ``tool_calls`` fragments, or ``None`` when more tokens are needed.
    """
    # If no delta text, return None unless
    # it's an EOS token after tool calls
    if not delta_text:
        # Check if this is an EOS token after all tool calls are complete
        # We check for tool calls in the text even if is_tool_call_started
        # is False because it might have been reset after processing all tools
        if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids:
            # Count complete tool calls
            complete_calls = len(
                self.tool_call_complete_regex.findall(current_text)
            )

            # If we have completed tool calls and populated prev_tool_call_arr
            if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
                # Check if all tool calls are closed
                open_calls = current_text.count(
                    self.tool_call_start_token
                ) - current_text.count(self.tool_call_end_token)
                if open_calls == 0:
                    # Return empty delta message to allow finish_reason processing
                    return DeltaMessage(content="")
            elif not self.is_tool_call_started and current_text:
                # This is a regular content response that's now complete
                return DeltaMessage(content="")
        return None

    # Check if this is the first call (reset state if needed)
    if not previous_text:
        self._reset_streaming_state()

    # Update accumulated text
    self.accumulated_text = current_text

    # Check if we need to advance to next tool
    if self.json_closed and not self.in_function:
        # Check if this tool call has ended
        tool_ends = current_text.count(self.tool_call_end_token)
        if tool_ends > self.current_tool_index:
            # This tool has ended, advance to next
            self.current_tool_index += 1
            self.header_sent = False
            self.param_count = 0
            self.json_started = False
            self.json_closed = False

            # Check if there are more tool calls
            if self.current_tool_index >= current_text.count(
                self.tool_call_start_token
            ):
                # No more tool calls
                self.is_tool_call_started = False
            # Continue processing next tool
            return None

    # Check if end thinking
    if not self.is_thinking_end and (
        self.think_end_token_id in delta_token_ids
        or self.think_end_token in delta_text
    ):
        self.is_thinking_end = True

    # If thinking hasn't ended yet, don't process any tool calls
    if not self.is_thinking_end:
        return DeltaMessage(content=delta_text)

    # Handle normal content before tool calls
    if not self.is_tool_call_started:
        # Check if tool call is starting
        if (
            self.tool_call_start_token_id in delta_token_ids
            or self.tool_call_start_token in delta_text
        ):
            self.is_tool_call_started = True
            # Return any content before the tool call
            if self.tool_call_start_token in delta_text:
                content_before = delta_text[
                    : delta_text.index(self.tool_call_start_token)
                ]
                if content_before:
                    return DeltaMessage(content=content_before)
            return None
        else:
            # Check if we're between tool calls - skip whitespace
            if (
                current_text.rstrip().endswith(self.tool_call_end_token)
                and delta_text.strip() == ""
            ):
                # We just ended a tool call, skip whitespace
                return None
            # Normal content, no tool call
            return DeltaMessage(content=delta_text)

    # Check if we're between tool calls (waiting for next one)
    # Count tool calls we've seen vs processed
    tool_starts_count = current_text.count(self.tool_call_start_token)
    if self.current_tool_index >= tool_starts_count:
        # We're past all tool calls, shouldn't be here
        return None

    # We're in a tool call, find the current tool call portion
    # Need to find the correct tool call based on current_tool_index
    # Only process tool calls after think_end_token
    think_end_index = (
        current_text.find(self.think_end_token) + len(self.think_end_token)
        if self.think_end_token in current_text
        else 0
    )
    # Collect the start offset of every tool call after the thinking block.
    tool_starts: list[int] = []
    idx = think_end_index
    while True:
        idx = current_text.find(self.tool_call_start_token, idx)
        if idx == -1:
            break
        tool_starts.append(idx)
        idx += len(self.tool_call_start_token)

    if self.current_tool_index >= len(tool_starts):
        # No more tool calls to process yet
        return None

    tool_start_idx = tool_starts[self.current_tool_index]
    # Find where this tool call ends (or current position if not ended yet)
    tool_end_idx = current_text.find(self.tool_call_end_token, tool_start_idx)
    if tool_end_idx == -1:
        tool_text = current_text[tool_start_idx:]
    else:
        tool_text = current_text[
            tool_start_idx : tool_end_idx + len(self.tool_call_end_token)
        ]

    # Looking for function header
    if not self.header_sent:
        if self.tool_call_prefix in tool_text:
            func_start = tool_text.find(self.tool_call_prefix) + len(
                self.tool_call_prefix
            )
            # Function name runs from the prefix up to the next ">".
            func_end = tool_text.find(">", func_start)

            if func_end != -1:
                # Found complete function name
                self.current_function_name = tool_text[func_start:func_end]
                self.current_tool_id = self._generate_tool_call_id()  # type: ignore
                self.header_sent = True
                self.in_function = True

                # IMPORTANT: Add to prev_tool_call_arr immediately when we detect a tool call
                # This ensures finish_reason="tool_calls" even if parsing isn't complete
                already_added = any(
                    tool.get("name") == self.current_function_name
                    for tool in self.prev_tool_call_arr
                )
                if not already_added:
                    self.prev_tool_call_arr.append(
                        {
                            "name": self.current_function_name,
                            "arguments": "{}",  # Placeholder, will be updated later
                        }
                    )

                # Send header with function info
                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_index,
                            id=self.current_tool_id,
                            function=DeltaFunctionCall(
                                name=self.current_function_name, arguments=""
                            ),
                            type="function",
                        )
                    ]
                )
        return None

    # We've sent header, now handle function body
    if self.in_function:
        # Send opening brace if not sent yet
        if not self.json_started and self.parameter_prefix not in delta_text:
            self.json_started = True
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_index,
                        function=DeltaFunctionCall(arguments="{"),
                    )
                ]
            )

        # Make sure json_started is set if we're processing parameters
        if not self.json_started:
            self.json_started = True

        # Check for function end in accumulated text
        if not self.json_closed and self.function_end_token in tool_text:
            # Close JSON
            self.json_closed = True

            # Extract the complete tool call to update prev_tool_call_arr with final arguments
            # Find the function content
            func_start = tool_text.find(self.tool_call_prefix) + len(
                self.tool_call_prefix
            )
            func_content_end = tool_text.find(self.function_end_token, func_start)
            if func_content_end != -1:
                func_content = tool_text[func_start:func_content_end]
                # Parse to get the complete arguments
                try:
                    parsed_tool = self._parse_xml_function_call(
                        func_content, request.tools if request else None
                    )
                    if parsed_tool:
                        # Update existing entry in prev_tool_call_arr with complete arguments
                        for i, tool in enumerate(self.prev_tool_call_arr):
                            if tool.get("name") == parsed_tool.function.name:
                                self.prev_tool_call_arr[i]["arguments"] = (
                                    parsed_tool.function.arguments
                                )
                                break
                except Exception:
                    # Best-effort: a parse failure only affects the recorded
                    # final arguments, not the streamed fragments.
                    logger.warning(
                        "Failed to parse tool arguments during streaming.",
                        exc_info=True,
                    )

            result = DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_index,
                        function=DeltaFunctionCall(arguments="}"),
                    )
                ]
            )

            # Reset state for next tool
            self.in_function = False
            self.json_closed = True

            return result

        # Look for parameters
        # Count how many complete parameters we have processed
        complete_params = tool_text.count(self.parameter_end_token)

        # Check if we should start a new parameter
        if not self.in_param and self.param_count < complete_params:
            # Find the unprocessed parameter
            # Count parameter starts
            param_starts = []
            idx = 0
            while True:
                idx = tool_text.find(self.parameter_prefix, idx)
                if idx == -1:
                    break
                param_starts.append(idx)
                idx += len(self.parameter_prefix)

            if len(param_starts) > self.param_count:
                # Process the next parameter
                param_idx = param_starts[self.param_count]
                param_start = param_idx + len(self.parameter_prefix)
                remaining = tool_text[param_start:]

                if ">" in remaining:
                    # We have the complete parameter name
                    name_end = remaining.find(">")
                    self.current_param_name = remaining[:name_end]

                    # Find the parameter value
                    value_start = param_start + name_end + 1
                    value_text = tool_text[value_start:]
                    if value_text.startswith("\n"):
                        value_text = value_text[1:]

                    # Find where this parameter ends
                    param_end_idx = value_text.find(self.parameter_end_token)
                    if param_end_idx != -1:
                        # Complete parameter found
                        param_value = value_text[:param_end_idx]
                        if param_value.endswith("\n"):
                            param_value = param_value[:-1]

                        # Build complete JSON fragment for this parameter
                        # NOTE(review): values are always emitted as quoted
                        # JSON strings here; any numeric/boolean coercion
                        # presumably happens downstream — confirm.
                        # json.dumps(...)[1:-1] yields the escaped body
                        # without the surrounding quotes.
                        if self.param_count == 0:
                            json_fragment = (
                                '"'
                                + self.current_param_name
                                + '": "'
                                + json.dumps(param_value)[1:-1]
                                + '"'
                            )
                        else:
                            json_fragment = (
                                ', "'
                                + self.current_param_name
                                + '": "'
                                + json.dumps(param_value)[1:-1]
                                + '"'
                            )

                        self.param_count += 1

                        return DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_index,
                                    function=DeltaFunctionCall(
                                        arguments=json_fragment
                                    ),
                                )
                            ]
                        )

        # Continue parameter value
        # NOTE(review): nothing in this method ever sets self.in_param to
        # True; unless it is set elsewhere on the instance, the branch
        # below looks unreachable — verify against the rest of the class.
        if self.in_param:
            if self.parameter_end_token in delta_text:
                # End of parameter
                end_idx = delta_text.find(self.parameter_end_token)
                value_chunk = delta_text[:end_idx]

                # Skip past > if at start
                if not self.current_param_value and ">" in value_chunk:
                    gt_idx = value_chunk.find(">")
                    value_chunk = value_chunk[gt_idx + 1 :]

                if not self.current_param_value and value_chunk.startswith("\n"):
                    value_chunk = value_chunk[1:]

                # Calculate incremental JSON
                # Diff the fully-escaped value against the previously
                # streamed escaped prefix so only new characters are sent.
                full_value = self.current_param_value + value_chunk
                prev_escaped = (
                    json.dumps(self.current_param_value)[1:-1]
                    if self.current_param_value
                    else ""
                )
                full_escaped = json.dumps(full_value)[1:-1]
                delta_escaped = full_escaped[len(prev_escaped) :]

                self.in_param = False
                self.current_param_value = ""

                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_index,
                            function=DeltaFunctionCall(
                                arguments=delta_escaped + '"'
                            ),
                        )
                    ]
                )
            else:
                # Continue accumulating value
                value_chunk = delta_text

                # Handle first chunk after param name
                if not self.current_param_value and ">" in value_chunk:
                    gt_idx = value_chunk.find(">")
                    value_chunk = value_chunk[gt_idx + 1 :]

                if not self.current_param_value and value_chunk.startswith("\n"):
                    value_chunk = value_chunk[1:]

                if value_chunk:
                    # Stream the escaped delta
                    prev_escaped = (
                        json.dumps(self.current_param_value)[1:-1]
                        if self.current_param_value
                        else ""
                    )
                    self.current_param_value += value_chunk
                    full_escaped = json.dumps(self.current_param_value)[1:-1]
                    delta_escaped = full_escaped[len(prev_escaped) :]

                    if delta_escaped:
                        return DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_index,
                                    function=DeltaFunctionCall(
                                        arguments=delta_escaped
                                    ),
                                )
                            ]
                        )

    return None
|
||||
303
vllm/tool_parsers/step3_tool_parser.py
Normal file
303
vllm/tool_parsers/step3_tool_parser.py
Normal file
@@ -0,0 +1,303 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Step3ToolParser(ToolParser):
    """
    Tool parser for a model that uses a specific XML-like format for tool calls.
    This version uses a robust, stateful, cursor-based streaming parser and
    consolidates tool arguments into a single message.
    """

    # Model-specific sentinel tokens that delimit the tool-call section and
    # each individual call within it.
    TOOL_CALLS_BEGIN = "<|tool_calls_begin|>"
    TOOL_CALLS_END = "<|tool_calls_end|>"
    TOOL_CALL_BEGIN = "<|tool_call_begin|>"
    TOOL_CALL_END = "<|tool_call_end|>"
    TOOL_SEP = "<|tool_sep|>"
    SPECIAL_TOKENS = [TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END]

    def __init__(self, tokenizer: TokenizerLike):
        """Initialize per-request streaming state on top of the base parser."""
        super().__init__(tokenizer)
        # Cursor into current_text; everything before it has been consumed.
        self.position = 0
        # Explicit state flags for robust streaming
        self.tool_block_started = False
        self.tool_block_finished = False

    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Keep special tokens in the detokenized output when tools are in
        play, so the sentinel markers above survive for this parser."""
        request = super().adjust_request(request)
        if request.tools and request.tool_choice != "none":
            request.skip_special_tokens = False
        return request

    @staticmethod
    def _parse_steptml_invoke(
        action_text: str,
    ) -> tuple[str | None, dict[str, str] | None]:
        """Parse a steptml invoke body into (function name, {param: raw str}).

        Returns (None, None) when no ``<steptml:invoke>`` tag is found.
        Values are kept as stripped strings; type coercion is done later by
        ``_cast_arguments``.
        """
        func_name_match = re.search(r'<steptml:invoke name="([^"]+)">', action_text)
        if not func_name_match:
            return None, None
        func_name = func_name_match.group(1)

        params: dict[str, str] = {}
        param_matches = re.findall(
            r'<steptml:parameter name="([^"]+)">([^<]*)</steptml:parameter>',
            action_text,
        )
        for name, value in param_matches:
            params[name] = value.strip()
        return func_name, params

    def _cast_arguments(
        self,
        func_name: str,
        params: dict[str, Any],
        request: ChatCompletionRequest,
    ) -> dict[str, Any]:
        """Coerce string parameter values to the JSON-schema types declared
        for *func_name* in the request's tools.

        Mutates *params* in place and also returns it. Values that fail to
        convert are left as the original strings rather than raising.
        """
        for tool in request.tools or []:
            if tool.function.name == func_name:
                schema = tool.function.parameters or {}
                properties = schema.get("properties", {})
                for key, value in params.items():
                    if not isinstance(value, str):
                        continue
                    prop = properties.get(key, {})
                    typ = prop.get("type")
                    if typ == "string":
                        params[key] = value.strip()
                    elif typ == "integer":
                        with contextlib.suppress(ValueError):
                            params[key] = int(value)
                    elif typ == "number":
                        with contextlib.suppress(ValueError):
                            params[key] = float(value)
                    elif typ == "boolean":
                        lower_val = value.lower()
                        params[key] = (
                            lower_val == "true"
                            if lower_val in ("true", "false")
                            else value
                        )
                    elif typ == "null":
                        params[key] = None if value.lower() == "null" else value
                break
        return params

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Cursor-based streaming extraction.

        ``self.position`` records how much of *current_text* has already
        been consumed; each loop iteration consumes a sentinel token, emits
        plain content, or parses (part of) one tool call, returning ``None``
        whenever more tokens are required to make progress.
        """
        # The main loop processes the stream from the last known position.
        while True:
            if self.position >= len(current_text):
                return None  # We've processed the entire stream.

            unprocessed_text = current_text[self.position :]

            # STATE: After all tools are done, all subsequent text is content.
            if self.tool_block_finished:
                self.position = len(current_text)
                return DeltaMessage(content=unprocessed_text)

            # STATE: Before the tool block has started.
            if not self.tool_block_started:
                if unprocessed_text.startswith(self.TOOL_CALLS_BEGIN):
                    self.position += len(self.TOOL_CALLS_BEGIN)
                    self.tool_block_started = True
                    continue  # Token consumed, re-loop.

                start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN)
                if start_pos == -1:
                    if (
                        self.TOOL_CALLS_BEGIN.startswith(unprocessed_text.strip())
                        and unprocessed_text
                    ):
                        return None  # It's a prefix, wait.
                    self.position = len(current_text)
                    return DeltaMessage(content=unprocessed_text)
                else:
                    # Emit the plain content that precedes the tool block.
                    content = unprocessed_text[:start_pos]
                    self.position += len(content)
                    return DeltaMessage(content=content)

            # STATE: Inside the main tool block.
            # Skip (and consume) any leading whitespace between markers.
            offset = len(unprocessed_text) - len(unprocessed_text.lstrip())
            unprocessed_text = unprocessed_text.lstrip()
            self.position += offset

            if unprocessed_text.startswith(self.TOOL_CALLS_END):
                self.position += len(self.TOOL_CALLS_END)
                self.tool_block_finished = True
                self.current_tool_id = -1
                continue

            # Check if we are between tool calls.
            tool_finished = self.current_tool_id != -1 and self.prev_tool_call_arr[
                self.current_tool_id
            ].get("finished")
            if self.current_tool_id == -1 or tool_finished:
                if unprocessed_text.startswith(self.TOOL_CALL_BEGIN):
                    self.position += len(self.TOOL_CALL_BEGIN)
                    if self.current_tool_id == -1:
                        self.current_tool_id = 0
                    else:
                        self.current_tool_id += 1
                    self.current_tool_name_sent = False
                    # Grow the bookkeeping array so the new index is valid.
                    while len(self.prev_tool_call_arr) <= self.current_tool_id:
                        self.prev_tool_call_arr.append({})
                    self.prev_tool_call_arr[self.current_tool_id]["finished"] = False
                    continue

                if self.TOOL_CALL_BEGIN.startswith(unprocessed_text):
                    return None

            # STATE: Parsing an active tool call.
            if self.current_tool_id != -1 and not self.prev_tool_call_arr[
                self.current_tool_id
            ].get("finished", False):
                end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END)
                if end_tool_pos == -1:
                    tool_body = unprocessed_text
                else:
                    tool_body = unprocessed_text[:end_tool_pos]

                # A bare prefix of the end marker: wait for more tokens.
                if end_tool_pos == -1 and self.TOOL_CALL_END.startswith(tool_body):
                    return None

                function_name, arguments = self._parse_steptml_invoke(tool_body)
                if not function_name:
                    return None

                tool_call_arr = {"name": function_name, "parameters": arguments or {}}

                # Send the function name as soon as it's parsed.
                if not self.current_tool_name_sent:
                    self.current_tool_name_sent = True
                    self.prev_tool_call_arr[self.current_tool_id].update(tool_call_arr)
                    return DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                type="function",
                                id=f"chatcmpl-tool-{random_uuid()}",
                                function=DeltaFunctionCall(name=function_name),
                            )
                        ]
                    )

                # Update our internal state with the latest parsed arguments.
                self.prev_tool_call_arr[self.current_tool_id].update(  # noqa: E501
                    tool_call_arr
                )

                # Only send arguments when the tool call is complete.
                if end_tool_pos != -1:
                    self.position += end_tool_pos + len(self.TOOL_CALL_END)
                    self.prev_tool_call_arr[self.current_tool_id]["finished"] = True

                    final_args = self._cast_arguments(
                        function_name,
                        tool_call_arr.get("parameters", {}),  # type: ignore
                        request,
                    )
                    if final_args:
                        # All arguments are consolidated into one delta.
                        final_args_json = json.dumps(final_args, ensure_ascii=False)
                        return DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_id,
                                    function=DeltaFunctionCall(
                                        arguments=final_args_json
                                    ),
                                )
                            ]
                        )

                # If tool is not finished, return None to wait for more tokens.
                return None

            return None

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Non-streaming extraction over the complete model output.

        Splits out the ``<|tool_calls_begin|>…<|tool_calls_end|>`` block,
        parses each ``<|tool_call_begin|>…<|tool_call_end|>`` segment, and
        returns the remaining pre/post text (stripped) as content.
        """
        if self.TOOL_CALLS_BEGIN not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        pre_text, rest = model_output.split(self.TOOL_CALLS_BEGIN, 1)
        # An unterminated tool block is treated as plain content.
        if self.TOOL_CALLS_END not in rest:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        tool_block, post_text = rest.split(self.TOOL_CALLS_END, 1)
        content = (pre_text + post_text).strip()

        tool_calls: list[ToolCall] = []
        call_parts = tool_block.split(self.TOOL_CALL_BEGIN)

        for part in call_parts:
            if not part or self.TOOL_CALL_END not in part:
                continue

            call_content = part.split(self.TOOL_CALL_END, 1)[0]
            if self.TOOL_SEP not in call_content:
                continue

            # Each call is "<type><|tool_sep|><steptml invoke body>"; only
            # type "function" is supported.
            type_part, invoke_part = call_content.split(self.TOOL_SEP, 1)
            if type_part.strip() != "function":
                continue

            function_name, params_dict = self._parse_steptml_invoke(invoke_part)

            if function_name and params_dict is not None:
                params_dict = self._cast_arguments(function_name, params_dict, request)
                params_str = json.dumps(params_dict, ensure_ascii=False)
                tool_calls.append(
                    ToolCall(
                        function=FunctionCall(name=function_name, arguments=params_str)
                    )
                )
        if tool_calls:
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None,
            )
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )
|
||||
229
vllm/tool_parsers/utils.py
Normal file
229
vllm/tool_parsers/utils.py
Normal file
@@ -0,0 +1,229 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from json import JSONDecodeError, JSONDecoder
|
||||
from typing import Any
|
||||
|
||||
import partial_json_parser
|
||||
from openai.types.responses import (
|
||||
FunctionTool,
|
||||
ToolChoiceFunction,
|
||||
)
|
||||
from openai.types.responses.tool import Tool
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionNamedToolChoiceParam,
|
||||
ChatCompletionToolsParam,
|
||||
)
|
||||
|
||||
|
||||
def find_common_prefix(s1: str, s2: str) -> str:
    """
    Return the longest prefix shared by *s1* and *s2* (possibly empty).
    Order of arguments is NOT important.

    This function is provided as a UTILITY for extracting information from
    JSON generated by partial_json_parser, helping ensure that close-quotes,
    close-brackets and close-braces are not returned to the client
    prematurely during streaming.

    e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
    '{"fruit": "ap'
    """
    matched = 0
    for a, b in zip(s1, s2):
        if a != b:
            break
        matched += 1
    return s1[:matched]
|
||||
|
||||
|
||||
def find_common_suffix(s1: str, s2: str) -> str:
    """
    Return the common trailing run of *non-alphanumeric* characters shared
    by *s1* and *s2*. Order of arguments is NOT important.

    Scanning stops as soon as the characters differ OR an alphanumeric
    character is reached, so only structural JSON characters (quotes,
    brackets, braces, ...) are ever returned.

    e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
    """
    collected: list[str] = []
    for a, b in zip(reversed(s1), reversed(s2)):
        if a != b or a.isalnum():
            break
        collected.append(a)
    return "".join(reversed(collected))
|
||||
|
||||
|
||||
def extract_intermediate_diff(curr: str, old: str) -> str:
    """
    Given two strings with a known common prefix and/or suffix, extract the
    difference in the middle.

    This function is provided as a UTILITY for extracting information from
    JSON generated by partial_json_parser, helping ensure that close-quotes,
    close-brackets and close-braces are not returned to the client
    prematurely during streaming. The order of arguments IS important: the
    new version of the partially-parsed JSON must be the first argument and
    the second argument must be from the previous generation.

    What it returns is the run of characters that should be streamed to the
    client.

    e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
    -> 'ple'
    """
    # Both strings end with this (non-alphanumeric) suffix, so slicing it
    # off the tail is equivalent to removing its last occurrence.
    shared_suffix = find_common_suffix(curr, old)
    trimmed_old = old[: len(old) - len(shared_suffix)] if shared_suffix else old
    shared_prefix = find_common_prefix(curr, trimmed_old)

    diff = curr[: len(curr) - len(shared_suffix)] if shared_suffix else curr
    if shared_prefix:
        # replace the prefix only once in case it's mirrored
        diff = diff.replace(shared_prefix, "", 1)
    return diff
|
||||
|
||||
|
||||
def find_all_indices(string: str, substring: str) -> list[int]:
    """
    Return the start index of every (possibly overlapping) occurrence of
    *substring* in *string*. Useful for tool call extraction.
    """
    found: list[int] = []
    pos = string.find(substring)
    while pos != -1:
        found.append(pos)
        pos = string.find(substring, pos + 1)
    return found
|
||||
|
||||
|
||||
# partial_json_parser doesn't support extra data and
|
||||
# JSONDecoder.raw_decode doesn't support partial JSON
|
||||
def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
    """
    Parse *input_str* as possibly-partial JSON.

    partial_json_parser cannot handle trailing extra data, while
    JSONDecoder.raw_decode cannot handle partial JSON, so the two are
    combined: partial parsing is tried first and, on an "Extra data" error,
    a strict raw_decode of the leading document is used instead.

    Returns a (parsed object, number of characters consumed) tuple.
    """
    try:
        parsed = partial_json_parser.loads(input_str, flags)
    except JSONDecodeError as e:
        if "Extra data" not in e.msg:
            raise
        return JSONDecoder().raw_decode(input_str)
    return (parsed, len(input_str))
|
||||
|
||||
|
||||
def is_complete_json(input_str: str) -> bool:
    """Return True if *input_str* parses as complete, valid JSON."""
    try:
        json.loads(input_str)
    except JSONDecodeError:
        return False
    return True
|
||||
|
||||
|
||||
def consume_space(i: int, s: str) -> int:
    """Return the index of the first non-whitespace character in *s* at or
    after position *i* (len(s) if the remainder is all whitespace)."""
    pos = i
    end = len(s)
    while pos < end and s[pos].isspace():
        pos += 1
    return pos
|
||||
|
||||
|
||||
def _extract_tool_info(
    tool: Tool | ChatCompletionToolsParam,
) -> tuple[str, dict[str, Any] | None]:
    """Return the (name, parameter schema) pair for either tool flavour.

    The Responses-API FunctionTool carries both at the top level, while the
    ChatCompletion tool param nests them under ``.function``.
    """
    if isinstance(tool, FunctionTool):
        return tool.name, tool.parameters
    if isinstance(tool, ChatCompletionToolsParam):
        return tool.function.name, tool.function.parameters
    raise TypeError(f"Unsupported tool type: {type(tool)}")
|
||||
|
||||
|
||||
def _get_tool_schema_from_tool(tool: Tool | ChatCompletionToolsParam) -> dict:
    """Build the JSON-schema object describing one tool call entry:
    a fixed "name" plus the tool's own parameter schema.

    Note: the tool's parameter dict is embedded by reference (not copied).
    """
    name, params = _extract_tool_info(tool)
    if not params:
        # Parameterless tools still need a valid (empty) object schema.
        params = {"type": "object", "properties": {}}
    return {
        "properties": {
            "name": {"type": "string", "enum": [name]},
            "parameters": params,
        },
        "required": ["name", "parameters"],
    }
|
||||
|
||||
|
||||
def _get_tool_schema_defs(
    tools: list[Tool | ChatCompletionToolsParam],
) -> dict:
    """Collect the ``$defs`` blocks from every tool's parameter schema.

    NOTE: ``params.pop("$defs", {})`` deliberately *mutates* each tool's
    parameter dict. Because _get_tool_schema_from_tool embeds those same
    dicts by reference, this pop also removes "$defs" from schemas already
    embedded in the combined schema, hoisting all definitions to a single
    top-level "$defs" block.

    Raises ValueError when two tools declare the same definition name with
    different schemas.
    """
    all_defs: dict[str, dict[str, Any]] = {}
    for tool in tools:
        _, params = _extract_tool_info(tool)
        if params is None:
            continue
        defs = params.pop("$defs", {})
        for def_name, def_schema in defs.items():
            if def_name in all_defs and all_defs[def_name] != def_schema:
                raise ValueError(
                    f"Tool definition '{def_name}' has multiple schemas, "
                    "which is not supported."
                )
            all_defs[def_name] = def_schema
    return all_defs
|
||||
|
||||
|
||||
def _get_json_schema_from_tools(
    tools: list[Tool | ChatCompletionToolsParam],
) -> dict:
    """Build a JSON schema constraining output to a non-empty array of tool
    calls, each matching one of the given tools.

    Order matters here: the per-tool "anyOf" entries are built first
    (embedding each tool's parameter dict by reference), and
    _get_tool_schema_defs then pops "$defs" out of those same dicts so the
    definitions appear exactly once at the top level of the schema.
    """
    json_schema = {
        "type": "array",
        "minItems": 1,
        "items": {
            "type": "object",
            "anyOf": [_get_tool_schema_from_tool(tool) for tool in tools],
        },
    }
    json_schema_defs = _get_tool_schema_defs(tools)
    if json_schema_defs:
        json_schema["$defs"] = json_schema_defs
    return json_schema
|
||||
|
||||
|
||||
def get_json_schema_from_tools(
    tool_choice: str | ToolChoiceFunction | ChatCompletionNamedToolChoiceParam,
    tools: list[FunctionTool | ChatCompletionToolsParam] | None,
) -> str | dict | None:
    """Translate an OpenAI ``tool_choice``/``tools`` pair into a JSON schema
    suitable for guided decoding.

    Returns None when no constraint applies ("none"/"auto"/no tools), the
    named tool's parameter schema for a forced function choice, or an
    array-of-calls schema for tool_choice="required".

    Raises ValueError when a forced tool name is not present in *tools*.
    """
    # tool_choice: "none" (or nothing to constrain against)
    if tool_choice in ("none", None) or tools is None:
        return None

    # tool_choice: forced function, Responses-API flavour
    if (not isinstance(tool_choice, str)) and isinstance(
        tool_choice, ToolChoiceFunction
    ):
        wanted = tool_choice.name
        by_name = {
            tool.name: tool for tool in tools if isinstance(tool, FunctionTool)
        }
        if wanted not in by_name:
            raise ValueError(f"Tool '{wanted}' has not been passed in `tools`.")
        return by_name[wanted].parameters

    # tool_choice: forced function, ChatCompletion flavour
    if (not isinstance(tool_choice, str)) and isinstance(
        tool_choice, ChatCompletionNamedToolChoiceParam
    ):
        wanted = tool_choice.function.name
        by_name = {
            tool.function.name: tool
            for tool in tools
            if isinstance(tool, ChatCompletionToolsParam)
        }
        if wanted not in by_name:
            raise ValueError(f"Tool '{wanted}' has not been passed in `tools`.")
        return by_name[wanted].function.parameters

    # tool_choice: "required" -> at least one call, drawn from any tool
    if tool_choice == "required":
        return _get_json_schema_from_tools(tools)

    # tool_choice: "auto" -> no constraint
    return None
|
||||
556
vllm/tool_parsers/xlam_tool_parser.py
Normal file
556
vllm/tool_parsers/xlam_tool_parser.py
Normal file
@@ -0,0 +1,556 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class xLAMToolParser(ToolParser):
    """Tool parser for xLAM function-calling models.

    xLAM emits tool calls as a JSON array of ``{"name": ..., "arguments": ...}``
    objects. The array may appear bare, inside a ```json code fence, after a
    ``[TOOL_CALLS]`` marker, inside ``<tool_call>`` tags, or following a
    ``</think>`` reasoning block (see the regex patterns set up below).
    """

    def __init__(self, tokenizer: TokenizerLike):
        """Initialize per-request parser state.

        Args:
            tokenizer: Tokenizer of the served model, forwarded to the base
                ``ToolParser``.
        """
        super().__init__(tokenizer)

        # Initialize state for streaming mode
        self.prev_tool_calls: list[dict] = []
        # Index of the tool currently being streamed (-1 = none yet).
        self.current_tool_id = -1
        self.current_tool_name_sent = False
        self.streamed_args: list[str] = []  # Track arguments sent for each tool

        # For backward compatibility with tests (tests may pre-seed this to
        # force the name-first emission path in streaming mode).
        self.current_tools_sent: list[bool] = []

        # For backward compatibility with serving code
        self.prev_tool_call_arr = []

        # Regex patterns for preprocessing: each captures a candidate JSON
        # payload from one of the wrapping styles xLAM may produce.
        self.json_code_block_patterns = [
            r"```(?:json)?\s*([\s\S]*?)```",
            r"\[TOOL_CALLS\]([\s\S]*?)(?=\n|$)",
            r"<tool_call>([\s\S]*?)</tool_call>",
        ]
        # Captures everything after a closing </think> reasoning tag.
        self.thinking_tag_pattern = r"</think>([\s\S]*)"

        # Streaming state: which tool is active, the call ids handed out so
        # far, and per-tool bookkeeping of what has already been emitted.
        self.streaming_state: dict[str, Any] = {
            "current_tool_index": -1,
            "tool_ids": [],
            "sent_tools": [],
        }
|
||||
|
||||
def preprocess_model_output(
|
||||
self, model_output: str
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Preprocess the model output to extract content and potential tool calls.
|
||||
Returns:
|
||||
Tuple of (content, potential_tool_calls_json)
|
||||
"""
|
||||
# Check for thinking tag
|
||||
thinking_match = re.search(self.thinking_tag_pattern, model_output)
|
||||
if thinking_match:
|
||||
content = model_output[: thinking_match.start() + len("</think>")].strip()
|
||||
thinking_content = thinking_match.group(1).strip()
|
||||
|
||||
# Try to parse the thinking content as JSON
|
||||
try:
|
||||
json.loads(thinking_content)
|
||||
return content, thinking_content
|
||||
except json.JSONDecodeError:
|
||||
# If can't parse as JSON, look for JSON code blocks
|
||||
for json_pattern in self.json_code_block_patterns:
|
||||
json_matches = re.findall(json_pattern, thinking_content)
|
||||
if json_matches:
|
||||
for json_str in json_matches:
|
||||
try:
|
||||
json.loads(json_str)
|
||||
return content, json_str
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# Check for JSON code blocks in the entire output
|
||||
for json_pattern in self.json_code_block_patterns:
|
||||
json_matches = re.findall(json_pattern, model_output)
|
||||
if json_matches:
|
||||
for json_str in json_matches:
|
||||
try:
|
||||
json.loads(json_str)
|
||||
# Extract content by removing the JSON code block
|
||||
content = re.sub(json_pattern, "", model_output).strip()
|
||||
return content, json_str
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# If the entire output is a valid JSON array or looks like one, treat it as tool calls
|
||||
if model_output.strip().startswith("["):
|
||||
try:
|
||||
json.loads(model_output)
|
||||
return None, model_output
|
||||
except json.JSONDecodeError:
|
||||
# Even if it's not valid JSON yet, it might be a tool call in progress
|
||||
if (
|
||||
"{" in model_output
|
||||
and "name" in model_output
|
||||
and "arguments" in model_output
|
||||
):
|
||||
return None, model_output
|
||||
|
||||
# If no tool calls found, return the original output as content
|
||||
return model_output, None
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str, request: ChatCompletionRequest
|
||||
) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract tool calls from a complete model output.
|
||||
"""
|
||||
try:
|
||||
# Preprocess the model output
|
||||
content, potential_tool_calls = self.preprocess_model_output(model_output)
|
||||
|
||||
if not potential_tool_calls:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=content
|
||||
)
|
||||
|
||||
# Parse the potential tool calls as JSON
|
||||
tool_calls_data = json.loads(potential_tool_calls)
|
||||
|
||||
# Ensure it's an array
|
||||
if not isinstance(tool_calls_data, list):
|
||||
logger.debug("Tool calls data is not an array")
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False,
|
||||
tool_calls=[],
|
||||
content=content or model_output,
|
||||
)
|
||||
|
||||
tool_calls: list[ToolCall] = []
|
||||
|
||||
for idx, call in enumerate(tool_calls_data):
|
||||
if (
|
||||
not isinstance(call, dict)
|
||||
or "name" not in call
|
||||
or "arguments" not in call
|
||||
):
|
||||
logger.debug("Invalid tool call format at index %d", idx)
|
||||
continue
|
||||
|
||||
tool_call = ToolCall(
|
||||
id=f"call_{idx}_{random_uuid()}",
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=call["name"],
|
||||
arguments=(
|
||||
json.dumps(call["arguments"])
|
||||
if isinstance(call["arguments"], dict)
|
||||
else call["arguments"]
|
||||
),
|
||||
),
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=len(tool_calls) > 0,
|
||||
tool_calls=tool_calls,
|
||||
content=content,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error extracting tool calls: %s", str(e))
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Extract tool calls for streaming mode.

        Called once per generation step. Decides whether ``current_text``
        looks like a tool-call block; if not, passes ``delta_text`` through
        as plain content. Otherwise it incrementally emits, per tool: first
        the tool name (with a fresh call id), then argument fragments as
        they appear, tracking progress in ``self.streaming_state``.

        Returns:
            A DeltaMessage with content or tool-call deltas, or ``None``
            when there is nothing new to emit yet.
        """
        # First, check for a definitive start of a tool call block.
        # This prevents premature parsing of incomplete output.
        stripped_text = current_text.strip()
        # NOTE(review): preprocessed_content is unused here; only the
        # tool-call half of the tuple matters in streaming mode.
        preprocessed_content, preprocessed_tool_calls = self.preprocess_model_output(
            current_text
        )

        # For JSON code blocks, we need to detect them earlier, even if incomplete
        has_potential_json_block = (
            "```json" in current_text
            or "```\n[" in current_text
            or "[TOOL_CALLS]" in current_text
            or "<tool_call>" in current_text
        )

        is_tool_call_block = (
            stripped_text.startswith("[")
            or stripped_text.startswith("<tool_call>")
            or stripped_text.startswith("[TOOL_CALLS]")
            or
            # Check if we have thinking tags with JSON-like content following
            ("</think>[" in current_text)
            or
            # Check if the text contains a JSON array after preprocessing
            preprocessed_tool_calls is not None
            or
            # For JSON code blocks, detect early if we see enough structure
            (
                has_potential_json_block
                and '"name"' in current_text
                and '"arguments"' in current_text
            )
        )

        # Not a tool-call block: stream the delta through as ordinary text.
        if not is_tool_call_block:
            return DeltaMessage(content=delta_text)

        try:
            # Initialize streaming state if not exists (defensive: __init__
            # already sets it, but state may have been cleared externally).
            if not hasattr(self, "streaming_state"):
                self.streaming_state = {
                    "current_tool_index": -1,
                    "tool_ids": [],
                    "sent_tools": [],  # Track complete state of each tool
                }

            # Try parsing as JSON to check for complete tool calls
            try:
                # Use preprocessed tool calls if available
                tool_calls_text = (
                    preprocessed_tool_calls if preprocessed_tool_calls else current_text
                )
                parsed_tools = json.loads(tool_calls_text)
                if isinstance(parsed_tools, list):
                    # Update our tool array for next time
                    self.prev_tool_call_arr = parsed_tools
            except json.JSONDecodeError:
                # Not complete JSON yet, use regex for partial parsing
                pass

            # Check for test-specific state setup (current_tools_sent)
            # This handles the case where tests manually set current_tools_sent
            if (
                hasattr(self, "current_tools_sent")  # type: ignore
                and len(self.current_tools_sent) > 0
            ):
                # If current_tools_sent is set to [False], it means the test wants us to send the name
                if (
                    len(self.current_tools_sent) == 1
                    and self.current_tools_sent[0] is False
                ):
                    # Extract the function name using regex
                    name_pattern = r'"name"\s*:\s*"([^"]+)"'
                    name_match = re.search(name_pattern, current_text)
                    if name_match:
                        function_name = name_match.group(1)

                        # The test expects us to send just the name first
                        tool_id = make_tool_call_id()
                        delta = DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    index=0,
                                    type="function",
                                    id=tool_id,
                                    function=DeltaFunctionCall(
                                        name=function_name
                                    ).model_dump(exclude_none=True),  # type: ignore
                                )
                            ]
                        )
                        # Update state to reflect that we've sent the name
                        self.current_tools_sent = [True]
                        self.current_tool_id = 0
                        self.streaming_state["current_tool_index"] = 0
                        if len(self.streaming_state["sent_tools"]) == 0:
                            self.streaming_state["sent_tools"].append(
                                {
                                    "sent_name": True,
                                    "sent_arguments_prefix": False,
                                    "sent_arguments": "",
                                }
                            )
                        else:
                            self.streaming_state["sent_tools"][0]["sent_name"] = True
                        self.current_tool_name_sent = True
                        return delta

            # Use regex to identify tool calls in the output
            # Use preprocessed tool calls text for better parsing, but also try to extract from incomplete JSON blocks
            search_text = (
                preprocessed_tool_calls if preprocessed_tool_calls else current_text
            )

            # For JSON code blocks that aren't complete yet, try to extract the JSON content
            if not preprocessed_tool_calls and has_potential_json_block:
                # Try to extract the JSON array from within the code block
                # (the fence may not be closed yet, hence the `|$`).
                json_match = re.search(
                    r"```(?:json)?\s*([\s\S]*?)(?:```|$)", current_text
                )
                if json_match:
                    potential_json = json_match.group(1).strip()
                    # Use this as search text even if it's incomplete
                    if potential_json.startswith("[") and (
                        '"name"' in potential_json and '"arguments"' in potential_json
                    ):
                        search_text = potential_json

            # Try to find complete tool names first
            name_pattern = r'"name"\s*:\s*"([^"]+)"'
            name_matches = list(re.finditer(name_pattern, search_text))
            tool_count = len(name_matches)

            # If no complete tool names found, check for partial tool names
            if tool_count == 0:
                # Check if we're in the middle of parsing a tool name
                partial_name_pattern = r'"name"\s*:\s*"([^"]*)'
                partial_matches = list(re.finditer(partial_name_pattern, search_text))
                if partial_matches:
                    # We have a partial tool name - not ready to emit yet
                    return None
                else:
                    # No tools found at all
                    return None

            # Ensure our state arrays are large enough
            while len(self.streaming_state["sent_tools"]) < tool_count:
                self.streaming_state["sent_tools"].append(
                    {
                        "sent_name": False,
                        "sent_arguments_prefix": False,
                        "sent_arguments": "",
                    }
                )

            while len(self.streaming_state["tool_ids"]) < tool_count:
                self.streaming_state["tool_ids"].append(None)

            # Determine if we need to move to a new tool
            current_idx = self.streaming_state["current_tool_index"]

            # If we haven't processed any tool yet or current tool is complete, move to next
            if current_idx == -1 or current_idx < tool_count - 1:
                next_idx = current_idx + 1

                # If tool at next_idx has not been sent yet
                if (
                    next_idx < tool_count
                    and not self.streaming_state["sent_tools"][next_idx]["sent_name"]
                ):
                    # Update indexes
                    self.streaming_state["current_tool_index"] = next_idx
                    self.current_tool_id = next_idx  # For backward compatibility
                    current_idx = next_idx

                    # Extract the tool name
                    tool_name = name_matches[current_idx].group(1)

                    # Generate ID and send tool name
                    tool_id = f"call_{current_idx}_{random_uuid()}"
                    self.streaming_state["tool_ids"][current_idx] = tool_id

                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=current_idx,
                                type="function",
                                id=tool_id,
                                function=DeltaFunctionCall(name=tool_name).model_dump(
                                    exclude_none=True
                                ),  # type: ignore
                            )
                        ]
                    )
                    self.streaming_state["sent_tools"][current_idx]["sent_name"] = True
                    self.current_tool_name_sent = True  # For backward compatibility

                    # Keep track of streamed args for backward compatibility
                    while len(self.streamed_args) <= current_idx:
                        self.streamed_args.append("")

                    return delta

            # Process arguments for the current tool
            if current_idx >= 0 and current_idx < tool_count:
                # Support both regular and empty argument objects
                # First, check for the empty arguments case: "arguments": {}
                empty_args_pattern = (
                    r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}'
                )
                empty_args_match = re.search(empty_args_pattern, search_text)

                # Check if this tool has empty arguments
                # NOTE(review): the start() > 0 guard and the loop below only
                # ever act on i == current_idx; empty_args_tool_idx is unused.
                if empty_args_match and empty_args_match.start() > 0:
                    # Find which tool this empty arguments belongs to
                    empty_args_tool_idx = 0
                    for i in range(tool_count):
                        if i == current_idx:
                            # If this is our current tool and it has empty arguments
                            if not self.streaming_state["sent_tools"][current_idx][
                                "sent_arguments_prefix"
                            ]:
                                # Send empty object
                                self.streaming_state["sent_tools"][current_idx][
                                    "sent_arguments_prefix"
                                ] = True
                                self.streaming_state["sent_tools"][current_idx][
                                    "sent_arguments"
                                ] = "{}"

                                # Update streamed_args for backward compatibility
                                while len(self.streamed_args) <= current_idx:
                                    self.streamed_args.append("")
                                self.streamed_args[current_idx] += "{}"

                                delta = DeltaMessage(
                                    tool_calls=[
                                        DeltaToolCall(
                                            index=current_idx,
                                            function=DeltaFunctionCall(
                                                arguments="{}"
                                            ).model_dump(exclude_none=True),  # type: ignore
                                        )
                                    ]
                                )

                                # Move to next tool if available
                                if current_idx < tool_count - 1:
                                    self.streaming_state["current_tool_index"] += 1
                                    self.current_tool_id = self.streaming_state[
                                        "current_tool_index"
                                    ]

                                return delta

                # Extract arguments for current tool using regex for non-empty arguments
                # (pattern handles one level of nested braces inside arguments).
                args_pattern = r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
                args_matches = list(re.finditer(args_pattern, search_text))

                if current_idx < len(args_matches):
                    args_text = args_matches[current_idx].group(1)

                    # Handle transition between tools
                    # NOTE(review): is_last_tool is computed but never read.
                    is_last_tool = current_idx == tool_count - 1

                    # For multiple tools, extract only the arguments for the current tool
                    if tool_count > 1:
                        # Parse the entire JSON structure to properly extract arguments for each tool
                        try:
                            parsed_tools = json.loads(search_text)
                            if isinstance(parsed_tools, list) and current_idx < len(
                                parsed_tools
                            ):
                                current_tool = parsed_tools[current_idx]
                                if isinstance(current_tool.get("arguments"), dict):
                                    args_text = json.dumps(current_tool["arguments"])
                                else:
                                    args_text = str(current_tool.get("arguments", "{}"))
                        except (json.JSONDecodeError, KeyError, IndexError):
                            # Fallback to regex-based extraction
                            pass

                    # What has already been streamed for this tool's arguments
                    sent_args = self.streaming_state["sent_tools"][current_idx][
                        "sent_arguments"
                    ]

                    # If we haven't sent the opening bracket yet
                    if not self.streaming_state["sent_tools"][current_idx][
                        "sent_arguments_prefix"
                    ] and args_text.startswith("{"):
                        self.streaming_state["sent_tools"][current_idx][
                            "sent_arguments_prefix"
                        ] = True
                        self.streaming_state["sent_tools"][current_idx][
                            "sent_arguments"
                        ] = "{"

                        # Update streamed_args for backward compatibility
                        while len(self.streamed_args) <= current_idx:
                            self.streamed_args.append("")
                        self.streamed_args[current_idx] += "{"

                        delta = DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    index=current_idx,
                                    function=DeltaFunctionCall(
                                        arguments="{"
                                    ).model_dump(exclude_none=True),  # type: ignore
                                )
                            ]
                        )
                        return delta

                    # If we need to send more arguments (args_text extends
                    # what was already streamed)
                    if args_text.startswith(sent_args):
                        # Calculate what part of arguments we need to send
                        args_diff = args_text[len(sent_args) :]

                        if args_diff:
                            # Update our state
                            self.streaming_state["sent_tools"][current_idx][
                                "sent_arguments"
                            ] = args_text

                            # Update streamed_args for backward compatibility
                            while len(self.streamed_args) <= current_idx:
                                self.streamed_args.append("")
                            self.streamed_args[current_idx] += args_diff

                            delta = DeltaMessage(
                                tool_calls=[
                                    DeltaToolCall(
                                        index=current_idx,
                                        function=DeltaFunctionCall(
                                            arguments=args_diff
                                        ).model_dump(exclude_none=True),  # type: ignore
                                    )
                                ]
                            )
                            return delta

                    # If the tool's arguments are complete, check if we need to move to the next tool
                    if args_text.endswith("}") and args_text == sent_args:
                        # This tool is complete, move to the next one in the next iteration
                        if current_idx < tool_count - 1:
                            self.streaming_state["current_tool_index"] += 1
                            self.current_tool_id = self.streaming_state[
                                "current_tool_index"
                            ]  # For compatibility

            # If we got here, we couldn't determine what to stream next
            return None

        except Exception as e:
            logger.exception(f"Error in streaming tool calls: {e}")
            # If we encounter an error, just return the delta text as regular content
            return DeltaMessage(content=delta_text)
|
||||
Reference in New Issue
Block a user