[gpt-oss] Add gpt-oss bf16 support

This commit is contained in:
2025-08-13 21:25:57 +08:00
parent 5d2e7edf78
commit 17ea2ec6aa
1232 changed files with 777 additions and 36 deletions

View File

@@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .abstract_tool_parser import ToolParser, ToolParserManager
from .deepseekv3_tool_parser import DeepSeekV3ToolParser
from .granite_20b_fc_tool_parser import Granite20bFCToolParser
from .granite_tool_parser import GraniteToolParser
from .hermes_tool_parser import Hermes2ProToolParser
from .internlm2_tool_parser import Internlm2ToolParser
from .jamba_tool_parser import JambaToolParser
from .llama4_pythonic_tool_parser import Llama4PythonicToolParser
from .llama_tool_parser import Llama3JsonToolParser
from .mistral_tool_parser import MistralToolParser
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
from .pythonic_tool_parser import PythonicToolParser
__all__ = [
"ToolParser", "ToolParserManager", "Granite20bFCToolParser",
"GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
"Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
"Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser",
"DeepSeekV3ToolParser"
]

View File

@@ -0,0 +1,164 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Sequence
from functools import cached_property
from typing import Callable, Optional, Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import import_from_path, is_list_of
logger = init_logger(__name__)
class ToolParser:
"""
Abstract ToolParser class that should not be used directly. Provided
properties and methods should be used in
derived classes.
"""
def __init__(self, tokenizer: AnyTokenizer):
self.prev_tool_call_arr: list[dict] = []
# the index of the tool call that is currently being parsed
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: list[str] = []
self.model_tokenizer = tokenizer
@cached_property
def vocab(self) -> dict[str, int]:
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
# whereas all tokenizers have .get_vocab()
return self.model_tokenizer.get_vocab()
def adjust_request(
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
"""
Static method that used to adjust the request parameters.
"""
return request
def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Static method that should be implemented for extracting tool calls from
a complete model-generated string.
Used for non-streaming responses where we have the entire model response
available before sending to the client.
Static because it's stateless.
"""
raise NotImplementedError(
"AbstractToolParser.extract_tool_calls has not been implemented!")
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
"""
Instance method that should be implemented for extracting tool calls
from an incomplete response; for use when handling tool calls and
streaming. Has to be an instance method because it requires state -
the current tokens/diffs, but also the information about what has
previously been parsed and extracted (see constructor)
"""
raise NotImplementedError(
"AbstractToolParser.extract_tool_calls_streaming has not been "
"implemented!")
class ToolParserManager:
tool_parsers: dict[str, type] = {}
@classmethod
def get_tool_parser(cls, name) -> type:
"""
Get tool parser by name which is registered by `register_module`.
Raise a KeyError exception if the name is not registered.
"""
if name in cls.tool_parsers:
return cls.tool_parsers[name]
raise KeyError(f"tool helper: '{name}' not found in tool_parsers")
@classmethod
def _register_module(cls,
module: type,
module_name: Optional[Union[str, list[str]]] = None,
force: bool = True) -> None:
if not issubclass(module, ToolParser):
raise TypeError(
f'module must be subclass of ToolParser, but got {type(module)}'
)
if module_name is None:
module_name = module.__name__
if isinstance(module_name, str):
module_name = [module_name]
for name in module_name:
if not force and name in cls.tool_parsers:
existed_module = cls.tool_parsers[name]
raise KeyError(f'{name} is already registered '
f'at {existed_module.__module__}')
cls.tool_parsers[name] = module
@classmethod
def register_module(
cls,
name: Optional[Union[str, list[str]]] = None,
force: bool = True,
module: Union[type, None] = None) -> Union[type, Callable]:
"""
Register module with the given name or name list. it can be used as a
decoder(with module as None) or normal function(with module as not
None).
"""
if not isinstance(force, bool):
raise TypeError(f'force must be a boolean, but got {type(force)}')
# raise the error ahead of time
if not (name is None or isinstance(name, str)
or is_list_of(name, str)):
raise TypeError(
'name must be None, an instance of str, or a sequence of str, '
f'but got {type(name)}')
# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
cls._register_module(module=module, module_name=name, force=force)
return module
# use it as a decorator: @x.register_module()
def _register(module):
cls._register_module(module=module, module_name=name, force=force)
return module
return _register
@classmethod
def import_tool_parser(cls, plugin_path: str) -> None:
"""
Import a user-defined tool parser by the path of the tool parser define
file.
"""
module_name = os.path.splitext(os.path.basename(plugin_path))[0]
try:
import_from_path(module_name, plugin_path)
except Exception:
logger.exception("Failed to load module '%s' from %s.",
module_name, plugin_path)
return

View File

@@ -0,0 +1,370 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Union
import regex as re
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
logger = init_logger(__name__)
@ToolParserManager.register_module("deepseek_v3")
class DeepSeekV3ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.streamed_args_for_tool: list[str] = (
[]) # map what has been streamed for each tool so far to a list
self.tool_calls_start_token: str = "<tool▁calls▁begin>"
self.tool_calls_end_token: str = "<tool▁calls▁end>"
self.tool_call_start_token: str = "<tool▁call▁begin>"
self.tool_call_end_token: str = "<tool▁call▁end>"
self.tool_call_regex = re.compile(
r"<tool▁call▁begin>(?P<type>.*)<tool▁sep>(?P<function_name>.*)\n```json\n(?P<function_arguments>.*)\n```<tool▁call▁end>"
)
self.stream_tool_call_portion_regex = re.compile(
r"(?P<type>.*)<tool▁sep>(?P<function_name>.*)\n```json\n(?P<function_arguments>.*[^\n`])"
)
self.stream_tool_call_name_regex = re.compile(
r"(?P<type>.*)<tool▁sep>(?P<function_name>.*)\n")
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction.")
self.tool_calls_start_token_id = self.vocab.get(
self.tool_calls_start_token)
self.tool_calls_end_token_id = self.vocab.get(
self.tool_calls_end_token)
self.tool_call_start_token_id = self.vocab.get(
self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if (self.tool_calls_start_token_id is None
or self.tool_calls_end_token_id is None):
raise RuntimeError(
"DeepSeek-V3 Tool parser could not locate tool call start/end "
"tokens in the tokenizer!")
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
# sanity check; avoid unnecessary processing
if self.tool_calls_start_token not in model_output:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
else:
try:
# there are two possible captures - between tags, or between a
# tag and end-of-string so the result of
# findall is an array of tuples where one is a function call and
# the other is None
function_call_tuples = self.tool_call_regex.findall(
model_output)
tool_calls = []
for match in function_call_tuples:
tool_type, function_name, function_args = match
tool_calls.append(
ToolCall(
type=tool_type,
function=FunctionCall(name=function_name,
arguments=function_args),
))
content = model_output[:model_output.
find(self.tool_calls_start_token)]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if content else None,
)
except Exception:
logger.exception(
"Error in extracting tool call from response.")
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
logger.debug("delta_text: %s", delta_text)
logger.debug("delta_token_ids: %s", delta_token_ids)
# check to see if we should be streaming a tool call - is there a
if self.tool_calls_start_token_id not in current_token_ids:
logger.debug("No tool call tokens found!")
return DeltaMessage(content=delta_text)
delta_text = delta_text.replace(self.tool_calls_start_token,
"").replace(self.tool_calls_end_token,
"")
try:
# figure out where we are in the parsing by counting tool call
# start & end tags
prev_tool_start_count = previous_token_ids.count(
self.tool_call_start_token_id)
prev_tool_end_count = previous_token_ids.count(
self.tool_call_end_token_id)
cur_tool_start_count = current_token_ids.count(
self.tool_call_start_token_id)
cur_tool_end_count = current_token_ids.count(
self.tool_call_end_token_id)
tool_call_portion = None
text_portion = None
# case: if we're generating text, OR rounding out a tool call
if (cur_tool_start_count == cur_tool_end_count
and prev_tool_end_count == cur_tool_end_count
and self.tool_call_end_token not in delta_text):
logger.debug("Generating text content! skipping tool parsing.")
return DeltaMessage(content=delta_text)
if self.tool_call_end_token in delta_text:
logger.debug("tool_call_end_token in delta_text")
full_text = current_text + delta_text
tool_call_portion = full_text.split(
self.tool_call_start_token)[-1].split(
self.tool_call_end_token)[0].rstrip()
delta_text = delta_text.split(
self.tool_call_end_token)[0].rstrip()
text_portion = delta_text.split(
self.tool_call_end_token)[-1].lstrip()
# case -- we're starting a new tool call
if (cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count > prev_tool_start_count):
if len(delta_token_ids) > 1:
tool_call_portion = current_text.split(
self.tool_call_start_token)[-1]
else:
tool_call_portion = None
delta = None
text_portion = None
# set cursors and state appropriately
self.current_tool_id += 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("Starting on a new tool %s", self.current_tool_id)
# case -- we're updating an existing tool call
elif (cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count == prev_tool_start_count):
# get the portion of the text that's the tool call
tool_call_portion = current_text.split(
self.tool_call_start_token)[-1]
text_portion = None
# case -- the current tool call is being closed.
elif (cur_tool_start_count == cur_tool_end_count
and cur_tool_end_count >= prev_tool_end_count):
if self.prev_tool_call_arr is None or len(
self.prev_tool_call_arr) == 0:
logger.debug(
"attempting to close tool call, but no tool call")
return None
diff = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments")
if diff:
diff = (diff.encode("utf-8").decode("unicode_escape")
if diff is str else diff)
if '"}' not in delta_text:
return None
end_loc = delta_text.rindex('"}')
diff = delta_text[:end_loc] + '"}'
logger.debug(
"Finishing tool and found diff that had not "
"been streamed yet: %s",
diff,
)
self.streamed_args_for_tool[self.current_tool_id] += diff
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=diff).model_dump(exclude_none=True),
)
])
# case -- otherwise we're just generating text
else:
text = delta_text.replace(self.tool_call_start_token, "")
text = text.replace(self.tool_call_end_token, "")
delta = DeltaMessage(tool_calls=[], content=text)
return delta
current_tool_call = dict()
if tool_call_portion:
current_tool_call_matches = (
self.stream_tool_call_portion_regex.match(
tool_call_portion))
if current_tool_call_matches:
tool_type, tool_name, tool_args = (
current_tool_call_matches.groups())
current_tool_call["name"] = tool_name
current_tool_call["arguments"] = tool_args
else:
current_tool_call_name_matches = (
self.stream_tool_call_name_regex.match(
tool_call_portion))
if current_tool_call_name_matches:
tool_type, tool_name = (
current_tool_call_name_matches.groups())
current_tool_call["name"] = tool_name
current_tool_call["arguments"] = ""
else:
logger.debug("Not enough token")
return None
# case - we haven't sent the tool name yet. If it's available, send
# it. otherwise, wait until it's available.
if not self.current_tool_name_sent:
if current_tool_call is None:
return None
function_name: Union[str, None] = current_tool_call.get("name")
if function_name:
self.current_tool_name_sent = True
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=f"chatcmpl-tool-{random_uuid()}",
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True),
)
])
else:
return None
# case -- otherwise, send the tool call delta
# if the tool call portion is None, send the delta as text
if tool_call_portion is None:
# if there's text but not tool calls, send that -
# otherwise None to skip chunk
delta = (DeltaMessage(
content=delta_text) if text_portion is not None else None)
return delta
# now, the nitty-gritty of tool calls
# now we have the portion to parse as tool call.
logger.debug("Trying to parse current tool call with ID %s",
self.current_tool_id)
# if we're starting a new tool call, push an empty object in as
# a placeholder for the arguments
if len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
# main logic for tool parsing here - compare prev. partially-parsed
# JSON to the current partially-parsed JSON
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments")
cur_arguments = current_tool_call.get("arguments")
logger.debug("diffing old arguments: %s", prev_arguments)
logger.debug("against new ones: %s", cur_arguments)
# case -- no arguments have been created yet. skip sending a delta.
if not cur_arguments and not prev_arguments:
logger.debug("Skipping text %s - no arguments", delta_text)
delta = None
# case -- prev arguments are defined, but non are now.
# probably impossible, but not a fatal error - just keep going
elif not cur_arguments and prev_arguments:
logger.error("should be impossible to have arguments reset "
"mid-call. skipping streaming anything.")
delta = None
# case -- we now have the first info about arguments available from
# autocompleting the JSON
elif cur_arguments and not prev_arguments:
delta = DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=cur_arguments).model_dump(
exclude_none=True),
)
])
self.streamed_args_for_tool[
self.current_tool_id] = cur_arguments
# last case -- we have an update to existing arguments.
elif cur_arguments and prev_arguments:
if (isinstance(delta_text, str)
and cur_arguments != prev_arguments
and len(cur_arguments) > len(prev_arguments)
and cur_arguments.startswith(prev_arguments)):
delta_arguments = cur_arguments[len(prev_arguments):]
logger.debug("got diff %s", delta_text)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=delta_arguments).model_dump(
exclude_none=True),
)
])
self.streamed_args_for_tool[
self.current_tool_id] = cur_arguments
else:
delta = None
# handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
self.prev_tool_call_arr[
self.current_tool_id] = current_tool_call
else:
self.prev_tool_call_arr.append(current_tool_call)
return delta
except Exception:
logger.exception("Error trying to handle streaming tool call.")
return None # do not stream a delta. skip this token ID.

View File

@@ -0,0 +1,259 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
from json import JSONDecoder
from typing import Union
import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import (consume_space,
find_common_prefix,
is_complete_json,
partial_json_loads)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
@ToolParserManager.register_module("granite-20b-fc")
class Granite20bFCToolParser(ToolParser):
"""
Tool call parser for the granite-20b-functioncalling model intended
for use with the examples/tool_chat_template_granite20b_fc.jinja
template.
Used when --enable-auto-tool-choice --tool-call-parser granite-20-fc
are all set
"""
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
self.bot_token = "<function_call>"
self.tool_start_token = self.bot_token
self.tool_call_regex = re.compile(r"<function_call>\s*")
def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
if self.tool_start_token not in model_output:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
dec = JSONDecoder()
try:
matches = list(self.tool_call_regex.finditer(model_output))
logger.debug("Found %d tool call matches", len(matches))
raw_function_calls = []
for i, match in enumerate(matches):
# position after the <function_call> tag
start_of_json = match.end()
# end_index == the start of the next function call
# (if exists)
next_function_call_start = (matches[i + 1].start() if i +
1 < len(matches) else None)
raw_function_calls.append(
dec.raw_decode(
model_output[start_of_json:next_function_call_start])
[0])
logger.debug("Extracted %d tool calls", len(raw_function_calls))
tool_calls = [
ToolCall(
type="function",
function=FunctionCall(
name=function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(function_call["arguments"],
ensure_ascii=False),
),
) for function_call in raw_function_calls
]
content = model_output[:model_output.find(self.bot_token)]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if content else None,
)
except Exception as e:
logger.error("Error in extracting tool call from response %s", e)
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
if len(current_text) < len(
self.bot_token) and self.bot_token.startswith(current_text):
return None
if not current_text.startswith(self.bot_token):
return DeltaMessage(content=delta_text)
# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/ function name at once.
flags = Allow.ALL if self.current_tool_name_sent \
else Allow.ALL & ~Allow.STR
try:
tool_call_arr = []
is_complete = []
try:
start_idx = len(self.bot_token)
start_idx = consume_space(start_idx, current_text)
while start_idx < len(current_text):
(obj,
end_idx) = partial_json_loads(current_text[start_idx:],
flags)
is_complete.append(
is_complete_json(current_text[start_idx:start_idx +
end_idx]))
start_idx += end_idx
start_idx = consume_space(start_idx, current_text)
start_idx += len(self.bot_token)
start_idx = consume_space(start_idx, current_text)
tool_call_arr.append(obj)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
return None
# select as the current tool call the one we're on the state at
current_tool_call: dict = tool_call_arr[self.current_tool_id] \
if len(tool_call_arr) > 0 else {}
# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if len(tool_call_arr) == 0:
return None
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
elif (len(tool_call_arr) > 0
and len(tool_call_arr) > self.current_tool_id + 1):
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if self.current_tool_id >= 0:
cur_arguments = current_tool_call.get("arguments")
if cur_arguments:
cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
argument_diff = cur_args_json[sent:]
logger.debug("got arguments diff: %s", argument_diff)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff
else:
delta = None
else:
delta = None
# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("starting on new tool %d", self.current_tool_id)
return delta
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
elif not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=random_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
])
self.current_tool_name_sent = True
else:
delta = None
# now we know we're on the same tool call and we're streaming
# arguments
else:
cur_arguments = current_tool_call.get("arguments")
delta = None
if cur_arguments:
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
prev_arguments = self.prev_tool_call_arr[
self.current_tool_id].get("arguments")
argument_diff = None
if is_complete[self.current_tool_id]:
argument_diff = cur_args_json[sent:]
elif prev_arguments:
prev_args_json = json.dumps(prev_arguments,
ensure_ascii=False)
if cur_args_json != prev_args_json:
prefix = find_common_prefix(
prev_args_json, cur_args_json)
argument_diff = prefix[sent:]
if argument_diff is not None:
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff
self.prev_tool_call_arr = tool_call_arr
return delta
except Exception as e:
logger.error("Error trying to handle streaming tool call: %s", e)
logger.debug(
"Skipping chunk as a result of tool streaming extraction "
"error")
return None

View File

@@ -0,0 +1,237 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
from typing import Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import (consume_space,
find_common_prefix,
is_complete_json,
partial_json_loads)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
@ToolParserManager.register_module("granite")
class GraniteToolParser(ToolParser):
"""
Tool call parser for the granite 3.0 models. Intended
for use with the examples/tool_chat_template_granite.jinja
template.
Used when --enable-auto-tool-choice --tool-call-parser granite
are all set
"""
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
# for granite 3.0, the token `<|tool_call|>`
self.bot_token = "<|tool_call|>"
# for granite 3.1, the string `<tool_call>`
self.bot_string = "<tool_call>"
def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
stripped = model_output.strip()\
.removeprefix(self.bot_token)\
.removeprefix(self.bot_string)\
.lstrip()
if not stripped or stripped[0] != '[':
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
try:
raw_function_calls = json.loads(stripped)
if not isinstance(raw_function_calls, list):
raise Exception(
f"Expected dict or list, got {type(raw_function_calls)}")
logger.debug("Extracted %d tool calls", len(raw_function_calls))
tool_calls = [
ToolCall(
type="function",
function=FunctionCall(
name=function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(function_call["arguments"],
ensure_ascii=False),
),
) for function_call in raw_function_calls
]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=None,
)
except Exception as e:
logger.error("Error in extracting tool call from response %s", e)
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
start_idx = consume_space(0, current_text)
if current_text[start_idx:].startswith(self.bot_token):
start_idx = consume_space(start_idx + len(self.bot_token),
current_text)
if current_text[start_idx:].startswith(self.bot_string):
start_idx = consume_space(start_idx + len(self.bot_string),
current_text)
if not current_text or start_idx >= len(current_text)\
or current_text[start_idx] != '[':
return DeltaMessage(content=delta_text)
# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/ function name at once.
flags = Allow.ALL if self.current_tool_name_sent \
else Allow.ALL & ~Allow.STR
try:
tool_call_arr = None
is_complete = None
try:
tool_calls, end_idx = partial_json_loads(
current_text[start_idx:], flags)
if type(tool_calls) is list:
tool_call_arr = tool_calls
else:
return DeltaMessage(content=delta_text)
is_complete = [True] * len(tool_calls)
if not is_complete_json(
current_text[start_idx:start_idx + end_idx]):
is_complete[-1] = False
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
return None
# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if not tool_call_arr:
return None
# select as the current tool call the one we're on the state at
current_tool_call: dict = tool_call_arr[self.current_tool_id]
delta = None
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
if len(tool_call_arr) > self.current_tool_id + 1:
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if self.current_tool_id >= 0:
cur_arguments = current_tool_call.get("arguments")
if cur_arguments:
cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
argument_diff = cur_args_json[sent:]
logger.debug("got arguments diff: %s", argument_diff)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff
# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("starting on new tool %d", self.current_tool_id)
return delta
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
elif not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=random_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
])
self.current_tool_name_sent = True
# now we know we're on the same tool call and we're streaming
# arguments
else:
cur_arguments = current_tool_call.get("arguments")
if cur_arguments:
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
prev_arguments = self.prev_tool_call_arr[
self.current_tool_id].get("arguments")
argument_diff = None
if is_complete[self.current_tool_id]:
argument_diff = cur_args_json[sent:]
elif prev_arguments:
prev_args_json = json.dumps(prev_arguments,
ensure_ascii=False)
if cur_args_json != prev_args_json:
prefix = find_common_prefix(
prev_args_json, cur_args_json)
argument_diff = prefix[sent:]
if argument_diff is not None:
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff
self.prev_tool_call_arr = tool_call_arr
return delta
except Exception as e:
logger.error("Error trying to handle streaming tool call: %s", e)
logger.debug(
"Skipping chunk as a result of tool streaming extraction "
"error")
return None

View File

@@ -0,0 +1,371 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
from typing import Union
import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
logger = init_logger(__name__)
@ToolParserManager.register_module("hermes")
class Hermes2ProToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
if isinstance(self.model_tokenizer, MistralTokenizer):
logger.error(
"Detected Mistral tokenizer when using a Hermes model")
self.model_tokenizer = self.model_tokenizer.tokenizer
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.streamed_args_for_tool: list[str] = [
] # map what has been streamed for each tool so far to a list
self.tool_call_start_token: str = "<tool_call>"
self.tool_call_end_token: str = "</tool_call>"
self.tool_call_regex = re.compile(
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
self.scratch_pad_regex = re.compile(
r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction.")
self.tool_call_start_token_id = self.vocab.get(
self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if (self.tool_call_start_token_id is None
or self.tool_call_end_token_id is None):
raise RuntimeError(
"Hermes 2 Pro Tool parser could not locate tool call start/end "
"tokens in the tokenizer!")
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
# sanity check; avoid unnecessary processing
if self.tool_call_start_token not in model_output:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
else:
try:
# there are two possible captures - between tags, or between a
# tag and end-of-string so the result of
# findall is an array of tuples where one is a function call and
# the other is None
function_call_tuples = (
self.tool_call_regex.findall(model_output))
# load the JSON, and then use it to build the Function and
# Tool Call
raw_function_calls = [
json.loads(match[0] if match[0] else match[1])
for match in function_call_tuples
]
tool_calls = [
ToolCall(
type="function",
function=FunctionCall(
name=function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(function_call["arguments"],
ensure_ascii=False)))
for function_call in raw_function_calls
]
content = model_output[:model_output.
find(self.tool_call_start_token)]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if content else None)
except Exception:
logger.exception(
"Error in extracting tool call from response.")
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
logger.debug("delta_text: %s", delta_text)
logger.debug("delta_token_ids: %s", delta_token_ids)
# check to see if we should be streaming a tool call - is there a
if self.tool_call_start_token_id not in current_token_ids:
logger.debug("No tool call tokens found!")
return DeltaMessage(content=delta_text)
try:
# figure out where we are in the parsing by counting tool call
# start & end tags
prev_tool_start_count = previous_token_ids.count(
self.tool_call_start_token_id)
prev_tool_end_count = previous_token_ids.count(
self.tool_call_end_token_id)
cur_tool_start_count = current_token_ids.count(
self.tool_call_start_token_id)
cur_tool_end_count = current_token_ids.count(
self.tool_call_end_token_id)
tool_call_portion = None
text_portion = None
# case: if we're generating text, OR rounding out a tool call
if (cur_tool_start_count == cur_tool_end_count
and prev_tool_end_count == cur_tool_end_count
and self.tool_call_end_token not in delta_text):
logger.debug("Generating text content! skipping tool parsing.")
return DeltaMessage(content=delta_text)
if self.tool_call_end_token in delta_text:
logger.debug("tool_call_end_token in delta_text")
full_text = current_text + delta_text
tool_call_portion = full_text.split(
self.tool_call_start_token)[-1].split(
self.tool_call_end_token)[0].rstrip()
delta_text = delta_text.split(
self.tool_call_end_token)[0].rstrip()
text_portion = delta_text.split(
self.tool_call_end_token)[-1].lstrip()
# case: if tool open & close tag counts don't match, we're doing
# imaginary "else" block here
# something with tools with this diff.
# flags for partial JSON parting. exported constants from
# "Allow" are handled via BIT MASK
flags = Allow.ALL if self.current_tool_name_sent \
else Allow.ALL & ~Allow.STR
# case -- we're starting a new tool call
if (cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count > prev_tool_start_count):
if len(delta_token_ids) > 1:
tool_call_portion = current_text.split(
self.tool_call_start_token)[-1]
else:
tool_call_portion = None
delta = None
text_portion = None
# set cursors and state appropriately
self.current_tool_id += 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("Starting on a new tool %s", self.current_tool_id)
# case -- we're updating an existing tool call
elif (cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count == prev_tool_start_count):
# get the portion of the text that's the tool call
tool_call_portion = current_text.split(
self.tool_call_start_token)[-1]
text_portion = None
# case -- the current tool call is being closed.
elif (cur_tool_start_count == cur_tool_end_count
and cur_tool_end_count >= prev_tool_end_count):
if (self.prev_tool_call_arr is None
or len(self.prev_tool_call_arr) == 0):
logger.debug(
"attempting to close tool call, but no tool call")
return None
diff = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments")
if diff:
diff = diff.encode('utf-8').decode(
'unicode_escape') if diff is str else diff
if ('"}' not in delta_text):
return None
end_loc = delta_text.rindex('"}')
diff = delta_text[:end_loc] + '"}'
logger.debug(
"Finishing tool and found diff that had not "
"been streamed yet: %s", diff)
self.streamed_args_for_tool[self.current_tool_id] \
+= diff
return DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=diff).model_dump(
exclude_none=True))
])
# case -- otherwise we're just generating text
else:
text = delta_text.replace(self.tool_call_start_token, "")
text = text.replace(self.tool_call_end_token, "")
delta = DeltaMessage(tool_calls=[], content=text)
return delta
try:
current_tool_call = partial_json_parser.loads(
tool_call_portion or "{}",
flags) if tool_call_portion else None
logger.debug("Parsed tool call %s", current_tool_call)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
return None
except json.decoder.JSONDecodeError:
logger.debug("unable to parse JSON")
return None
# case - we haven't sent the tool name yet. If it's available, send
# it. otherwise, wait until it's available.
if not self.current_tool_name_sent:
if (current_tool_call is None):
return None
function_name: Union[str, None] = current_tool_call.get("name")
if function_name:
self.current_tool_name_sent = True
return DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=random_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
])
else:
return None
# case -- otherwise, send the tool call delta
# if the tool call portion is None, send the delta as text
if tool_call_portion is None:
# if there's text but not tool calls, send that -
# otherwise None to skip chunk
delta = DeltaMessage(content=delta_text) \
if text_portion is not None else None
return delta
# now, the nitty-gritty of tool calls
# now we have the portion to parse as tool call.
logger.debug("Trying to parse current tool call with ID %s",
self.current_tool_id)
# if we're starting a new tool call, push an empty object in as
# a placeholder for the arguments
if len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
# main logic for tool parsing here - compare prev. partially-parsed
# JSON to the current partially-parsed JSON
prev_arguments = (
self.prev_tool_call_arr[self.current_tool_id].get("arguments"))
cur_arguments = current_tool_call.get("arguments")
logger.debug("diffing old arguments: %s", prev_arguments)
logger.debug("against new ones: %s", cur_arguments)
# case -- no arguments have been created yet. skip sending a delta.
if not cur_arguments and not prev_arguments:
logger.debug("Skipping text %s - no arguments", delta_text)
delta = None
# case -- prev arguments are defined, but non are now.
# probably impossible, but not a fatal error - just keep going
elif not cur_arguments and prev_arguments:
logger.error("should be impossible to have arguments reset "
"mid-call. skipping streaming anything.")
delta = None
# case -- we now have the first info about arguments available from
# autocompleting the JSON
elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments,
ensure_ascii=False)
logger.debug("finding %s in %s", delta_text,
cur_arguments_json)
# get the location where previous args differ from current
if (delta_text not in cur_arguments_json[:-2]):
return None
args_delta_start_loc = cur_arguments_json[:-2]. \
rindex(delta_text) + \
len(delta_text)
# use that to find the actual delta
arguments_delta = cur_arguments_json[:args_delta_start_loc]
logger.debug("First tokens in arguments received: %s",
arguments_delta)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=arguments_delta).model_dump(
exclude_none=True))
])
self.streamed_args_for_tool[self.current_tool_id] \
+= arguments_delta
# last case -- we have an update to existing arguments.
elif cur_arguments and prev_arguments:
if isinstance(delta_text, str) and len(delta_text.rstrip(
)) >= 1 and delta_text.rstrip()[-1] == '}':
delta_text = delta_text.rstrip()[:-1]
logger.debug("got diff %s", delta_text)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=delta_text).model_dump(
exclude_none=True))
])
self.streamed_args_for_tool[self.current_tool_id] \
+= delta_text
# handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
self.prev_tool_call_arr[self.current_tool_id] = \
current_tool_call
else:
self.prev_tool_call_arr.append(current_tool_call)
return delta
except Exception:
logger.exception("Error trying to handle streaming tool call.")
return None # do not stream a delta. skip this token ID.

View File

@@ -0,0 +1,216 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
from typing import Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
@ToolParserManager.register_module(["internlm"])
class Internlm2ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
self.position = 0
def adjust_request(
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
if request.tools and request.tool_choice != 'none':
# do not skip special tokens because internlm use the special
# tokens to indicated the start and end of the tool calls
# information.
request.skip_special_tokens = False
return request
def get_argments(self, obj):
if "parameters" in obj:
return obj.get("parameters")
elif "arguments" in obj:
return obj.get("arguments")
return None
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
if '<|action_start|>' not in current_text:
self.position = len(current_text)
return DeltaMessage(content=delta_text)
# if the tool call is sended, return a empty delta message
# to make sure the finish_reason will be send correctly.
if self.current_tool_id > 0:
return DeltaMessage(content='')
last_pos = self.position
if '<|action_start|><|plugin|>' not in current_text[last_pos:]:
return None
new_delta = current_text[last_pos:]
text, action = new_delta.split('<|action_start|><|plugin|>')
if len(text) > 0:
self.position = self.position + len(text)
return DeltaMessage(content=text)
action = action.strip()
action = action.split('<|action_end|>'.strip())[0]
# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/ function name at once.
flags = Allow.ALL if self.current_tool_name_sent \
else Allow.ALL & ~Allow.STR
try:
parsable_arr = action
# tool calls are generated in an object in inernlm2
# it's not support parallel tool calls
try:
tool_call_arr: dict = partial_json_parser.loads(
parsable_arr, flags)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
return None
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
if not self.current_tool_name_sent:
function_name = tool_call_arr.get("name")
if function_name:
self.current_tool_id = self.current_tool_id + 1
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=random_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
])
self.current_tool_name_sent = True
self.streamed_args_for_tool.append("")
else:
delta = None
# now we know we're on the same tool call and we're streaming
# arguments
else:
prev_arguments = self.get_argments(
self.prev_tool_call_arr[self.current_tool_id])
cur_arguments = self.get_argments(tool_call_arr)
# not arguments generated
if not cur_arguments and not prev_arguments:
delta = None
# will never happen
elif not cur_arguments and prev_arguments:
logger.error(
"INVARIANT - impossible to have arguments reset "
"mid-arguments")
delta = None
# first time to get parameters
elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments,
ensure_ascii=False)
arguments_delta = cur_arguments_json[:cur_arguments_json.
index(delta_text) +
len(delta_text)]
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=arguments_delta).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += arguments_delta
# both prev and cur parameters, send the increase parameters
elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
prev_args_json = json.dumps(prev_arguments,
ensure_ascii=False)
argument_diff = extract_intermediate_diff(
cur_args_json, prev_args_json)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).model_dump(
exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff
# check to see if the name is defined and has been sent. if so,
# stream the name - otherwise keep waiting
# finish by setting old and returning None as base case
tool_call_arr["arguments"] = self.get_argments(tool_call_arr)
self.prev_tool_call_arr = [tool_call_arr]
return delta
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction "
"error")
return None
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
text = model_output
tools = request.tools
if '<|action_start|><|plugin|>' in text:
text, action = text.split('<|action_start|><|plugin|>')
action = action.split('<|action_end|>'.strip())[0]
action = action[action.find('{'):]
action_dict = json.loads(action)
name, parameters = action_dict['name'], json.dumps(
action_dict.get('parameters', action_dict.get('arguments',
{})),
ensure_ascii=False)
if not tools or name not in [t.function.name for t in tools]:
ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=text)
tool_calls = [
ToolCall(
function=FunctionCall(name=name, arguments=parameters))
]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=text if len(text) > 0 else None)
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=text)

View File

@@ -0,0 +1,308 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
from typing import Union
import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizers import MistralTokenizer
logger = init_logger(__name__)
@ToolParserManager.register_module("jamba")
class JambaToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
if isinstance(self.model_tokenizer, MistralTokenizer):
raise ValueError(
"Detected a MistralTokenizer tokenizer when using a Jamba model"
)
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.streamed_args_for_tool: list[str] = [
] # map what has been streamed for each tool so far to a list
self.tool_calls_start_token: str = "<tool_calls>"
self.tool_calls_end_token: str = "</tool_calls>"
self.tool_calls_regex = re.compile(
rf"{self.tool_calls_start_token}(.*?){self.tool_calls_end_token}",
re.DOTALL)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction.")
self.tool_calls_start_token_id = self.vocab.get(
self.tool_calls_start_token)
self.tool_calls_end_token_id = self.vocab.get(
self.tool_calls_end_token)
if (self.tool_calls_start_token_id is None
or self.tool_calls_end_token_id is None):
raise RuntimeError(
"Jamba Tool parser could not locate tool calls start/end "
"tokens in the tokenizer!")
def adjust_request(
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
if request.tools and request.tool_choice != 'none':
# do not skip special tokens because jamba use the special
# tokens to indicate the start and end of the tool calls
# information.
request.skip_special_tokens = False
return request
def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
# sanity check; avoid unnecessary processing
if self.tool_calls_start_token not in model_output:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
else:
try:
# use a regex to find the tool call between the tags
function_calls = self.tool_calls_regex.findall(model_output)[0]
# load the JSON, and then use it to build the Function and
# Tool Call
raw_function_calls = json.loads(function_calls)
tool_calls = [
ToolCall(
type="function",
function=FunctionCall(
name=function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(function_call["arguments"],
ensure_ascii=False),
)) for function_call in raw_function_calls
]
content = model_output[:model_output.
find(self.tool_calls_start_token)]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if
(len(content) > 0 and content != " ") else None)
except Exception:
logger.exception(
"Error in extracting tool call from response.")
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
# if the tool call token is not in the tokens generated so far, append
# output to contents since it's not a tool
if self.tool_calls_start_token not in current_text:
return DeltaMessage(content=delta_text)
# if the tool call token ID IS in the tokens generated so far, that
# means we're parsing as tool calls now
# handle if we detected the start of tool calls token which means
# the start of tool calling
if (self.tool_calls_start_token_id in delta_token_ids
and len(delta_token_ids) == 1):
# if it's the only token, return None, so we don't send a chat
# completion and don't send a control token
return None
# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/ function name at once.
flags = Allow.ALL if self.current_tool_name_sent \
else Allow.ALL & ~Allow.STR
try:
# Extract the tool calls between the special tool call tokens
parsable_arr = current_text.split(
self.tool_calls_start_token)[-1].split(
self.tool_calls_end_token)[0]
# tool calls are generated in an array, so do partial JSON
# parsing on the entire array
try:
tool_call_arr: list[dict] = partial_json_parser.loads(
parsable_arr, flags)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
return None
# select as the current tool call the one we're on the state at
current_tool_call: dict = tool_call_arr[self.current_tool_id] \
if len(tool_call_arr) > 0 else {}
# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if len(tool_call_arr) == 0:
return None
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
elif (len(tool_call_arr) > 0
and len(tool_call_arr) > self.current_tool_id + 1):
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if self.current_tool_id >= 0:
diff: Union[str, None] = current_tool_call.get("arguments")
if diff:
diff = json.dumps(diff, ensure_ascii=False).replace(
self.streamed_args_for_tool[self.current_tool_id],
"")
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=diff).model_dump(
exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += diff
else:
delta = None
else:
delta = None
# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("starting on new tool %d", self.current_tool_id)
return delta
# case: update an existing tool - this is handled below
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
if not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=random_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
])
self.current_tool_name_sent = True
else:
delta = None
# now we know we're on the same tool call and we're streaming
# arguments
else:
prev_arguments = self.prev_tool_call_arr[
self.current_tool_id].get("arguments")
cur_arguments = current_tool_call.get("arguments")
new_text = delta_text.replace("\'", "\"")
if not cur_arguments and not prev_arguments:
delta = None
elif not cur_arguments and prev_arguments:
logger.error(
"INVARIANT - impossible to have arguments reset "
"mid-arguments")
delta = None
elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments,
ensure_ascii=False)
logger.debug("finding %s in %s", new_text,
cur_arguments_json)
arguments_delta = cur_arguments_json[:cur_arguments_json.
index(new_text) +
len(new_text)]
logger.debug("First tokens in arguments received: %s",
arguments_delta)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=arguments_delta).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += arguments_delta
elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
prev_args_json = json.dumps(prev_arguments,
ensure_ascii=False)
logger.debug("Searching for diff between \n%s\n%s",
cur_args_json, prev_args_json)
argument_diff = extract_intermediate_diff(
cur_args_json, prev_args_json)
logger.debug("got arguments diff: %s", argument_diff)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).model_dump(
exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff
else:
# try parsing it with regular JSON - if it works we're
# at the end, and we need to send the difference between
# tokens streamed so far and the valid JSON
delta = None
# check to see if the name is defined and has been sent. if so,
# stream the name - otherwise keep waiting
# finish by setting old and returning None as base case
self.prev_tool_call_arr = tool_call_arr
return delta
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction "
"error")
return None

View File

@@ -0,0 +1,316 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
from collections.abc import Sequence
from typing import Any, Union
import regex as re
from transformers import PreTrainedTokenizerBase
import vllm.envs as envs
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
logger = init_logger(__name__)
class _UnexpectedAstError(Exception):
pass
@ToolParserManager.register_module("llama4_pythonic")
class Llama4PythonicToolParser(ToolParser):
"""
Toolcall parser for Llama4 that produce tool calls in a pythonic style
Use --enable-auto-tool-choice --tool-call-parser llama4_pythonic
"""
# TODO(mdepinet): Possible future improvements:
# 1. Support text + tools separated by either <|python_tag|> or \n\n
# 2. Support tools outside of a list (or separated by a semicolon).
# This depends on item 1 for consistent streaming.
# Neither of these are necessary for e.g. ToolACE, but both would help make
# Llama3.2 models more reliable.
TOOL_CALL_REGEX = re.compile(
r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
re.DOTALL)
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
# Rename for readability. This is NOT a tool id.
@property
def current_tool_index(self) -> int:
return self.current_tool_id
@current_tool_index.setter
def current_tool_index(self, value: int) -> None:
self.current_tool_id = value
def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
"""
# remove <|python_start|> and <|python_end|>
# as Llama 4 model sometime will output those tokens
if model_output.startswith("<|python_start|>"):
model_output = model_output[len("<|python_start|>"):]
model_output = model_output.replace("<|python_end|>", "")
is_tool_call_pattern = False
try:
is_tool_call_pattern = self.TOOL_CALL_REGEX.match(
model_output,
timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS) is not None
except TimeoutError:
logger.warning(
"Regex timeout occurred when matching tool call pattern.")
logger.debug("Regex timeout occurred when matching user input: %s",
model_output)
if not is_tool_call_pattern:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
try:
module = ast.parse(model_output)
parsed = getattr(module.body[0], "value", None)
if isinstance(parsed, ast.List) and all(
isinstance(e, ast.Call) for e in parsed.elts):
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=[
_handle_single_tool(e) # type: ignore
for e in parsed.elts
],
content=None)
else:
raise _UnexpectedAstError(
"Tool output must be a list of function calls")
except Exception:
logger.exception("Error in extracting tool call from response.")
# Treat as regular text
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
if not current_text.startswith("[") and not current_text.startswith(
"<|python_start|>"):
return DeltaMessage(content=delta_text)
try:
# remove <|python_start|> and <|python_end|>
if current_text.startswith("<|python_start|>"):
current_text = current_text[len("<|python_start|>"):]
if current_text.endswith("<|python_end|>"):
current_text = current_text[:current_text.
rfind("<|python_end|>")]
valid_and_added_text = _make_valid_python(current_text)
if valid_and_added_text is None:
return None
valid_text, added_text = valid_and_added_text
module = ast.parse(valid_text)
parsed = getattr(module.body[0], "value", None)
if not isinstance(parsed, ast.List) or not all(
isinstance(e, ast.Call) for e in parsed.elts):
raise _UnexpectedAstError(
"Tool output must be a list of function calls")
tool_calls = [
_handle_single_tool(e) # type: ignore
for e in parsed.elts
]
tool_deltas = []
for index, new_call in enumerate(tool_calls):
if index < self.current_tool_index:
continue
self.current_tool_index = index
if len(self.streamed_args_for_tool) == index:
self.streamed_args_for_tool.append("")
new_call_complete = index < len(
tool_calls) - 1 or ")]" not in added_text
if new_call_complete:
self.current_tool_index += 1
withheld_suffix = (added_text[:-2]
if not new_call_complete else "")
if not new_call_complete and added_text[-2] == ")":
# Function call is incomplete. Withhold the closing bracket.
withheld_suffix = withheld_suffix + "}"
# Strings get single quotes in the model-produced string.
# JSON requires double quotes.
withheld_suffix = withheld_suffix.replace("'", '"')
delta = _compute_tool_delta(self.streamed_args_for_tool[index],
new_call, index, withheld_suffix)
if delta is not None:
tool_deltas.append(delta)
if (delta.function is not None
and delta.function.arguments is not None):
self.streamed_args_for_tool[
index] += delta.function.arguments
# HACK: serving_chat.py inspects the internal state of tool parsers
# when determining it's final streaming delta, automatically
# adding autocompleted JSON.
# These two lines avoid that nonsense while ensuring finish_reason
# is set to tool_calls when at least one tool is called.
if tool_deltas and not self.prev_tool_call_arr:
self.prev_tool_call_arr = [{"arguments": {}}]
if tool_deltas:
return DeltaMessage(tool_calls=tool_deltas)
elif not added_text and self.current_tool_id > 0:
# Return an empty DeltaMessage once the tool calls are all done
# so that finish_reason gets set.
return DeltaMessage(content='')
else:
return None
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction "
"error")
return None
def _get_parameter_value(val: ast.expr) -> Any:
if isinstance(val, ast.Constant):
return val.value
elif isinstance(val, ast.Dict):
if not all(isinstance(k, ast.Constant) for k in val.keys):
raise _UnexpectedAstError(
"Dict tool call arguments must have literal keys")
return {
k.value: _get_parameter_value(v) # type: ignore
for k, v in zip(val.keys, val.values)
}
elif isinstance(val, ast.List):
return [_get_parameter_value(v) for v in val.elts]
else:
raise _UnexpectedAstError("Tool call arguments must be literals")
def _handle_single_tool(call: ast.Call) -> ToolCall:
if not isinstance(call.func, ast.Name):
raise _UnexpectedAstError("Invalid tool call name")
function_name = call.func.id
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = _get_parameter_value(keyword.value)
return ToolCall(type="function",
function=FunctionCall(name=function_name,
arguments=json.dumps(arguments)))
def _make_valid_python(text: str) -> Union[tuple[str, str], None]:
bracket_stack = []
for index, char in enumerate(text):
if char in {"[", "(", "{"}:
bracket_stack.append(char)
elif char == "]":
if not bracket_stack or bracket_stack.pop() != "[":
raise _UnexpectedAstError("Mismatched square brackets")
elif char == ")":
if not bracket_stack or bracket_stack.pop() != "(":
raise _UnexpectedAstError("Mismatched parentheses")
elif char == "}":
if not bracket_stack or bracket_stack.pop() != "{":
raise _UnexpectedAstError("Mismatched curly braces")
elif char in {"'", '"'}:
if bracket_stack and bracket_stack[-1] == char:
if index > 0 and text[index - 1] == "\\":
# Treat an escaped quote as a regular character
pass
else:
bracket_stack.pop()
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
# Double quote within a single quote string or vice versa.
pass
else:
bracket_stack.append(char)
text = text.rstrip()
if text.endswith("=") or text.endswith(":"):
# Since we have no type information for this property/parameter value,
# we can't fill in a valid value.
return None
if bracket_stack and bracket_stack[-1] == "{":
trailing_dict_text = text[:text.rfind("{")]
num_keys = trailing_dict_text.count(":")
num_values = trailing_dict_text.count(",")
if num_keys <= num_values:
return None # Incomplete property name within parameter value
if bracket_stack and bracket_stack[-1] == "(":
trailing_params_text = text[:text.rfind("(")]
num_full_param_names = trailing_params_text.count("=")
num_full_param_values = trailing_params_text.count(",")
if num_full_param_names <= num_full_param_values:
return None # Incomplete parameter name
if text.endswith(","):
text = text[:-1]
if bracket_stack and bracket_stack[-1] == "[" and not text.endswith(
"[") and not text.endswith(")"):
return None # Incomplete function name
added_text = ""
for char in reversed(bracket_stack):
if char == "[":
added_text += "]"
elif char == "(":
added_text += ")"
elif char == "{":
added_text += "}"
elif char == "'":
added_text += "'"
elif char == '"':
added_text += '"'
return text + added_text, added_text
def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,
index: int,
withheld_suffix: str) -> Union[DeltaToolCall, None]:
new_call_args = new_call.function.arguments
if withheld_suffix:
assert new_call_args.endswith(withheld_suffix)
new_call_args = new_call_args[:-len(withheld_suffix)]
if not previously_sent_args:
return DeltaToolCall(id=new_call.id,
type="function",
index=index,
function=DeltaFunctionCall(
name=new_call.function.name,
arguments=new_call_args,
))
arg_diff = new_call_args[len(previously_sent_args):]
return DeltaToolCall(
id=None, index=index, function=DeltaFunctionCall(
arguments=arg_diff)) if arg_diff else None

View File

@@ -0,0 +1,267 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
from json import JSONDecoder
from typing import Union
import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix,
is_complete_json,
partial_json_loads)
from vllm.logger import init_logger
logger = init_logger(__name__)
@ToolParserManager.register_module("llama3_json")
@ToolParserManager.register_module("llama4_json")
class Llama3JsonToolParser(ToolParser):
"""
Tool call parser for Llama 3.1 models intended for use with the
examples/tool_chat_template_llama.jinja template.
Used when --enable-auto-tool-choice --tool-call-parser llama3_json
are all set
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
# initialize properties used for state when parsing tool calls in
# streaming mode
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: list[str] = [
] # map what has been streamed for each tool so far to a list
self.bot_token = "<|python_tag|>"
self.bot_token_id = tokenizer.encode(self.bot_token,
add_special_tokens=False)[0]
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
"""
# case -- if a tool call token is not present, return a text response
if not (model_output.startswith(self.bot_token)
or model_output.startswith('{')):
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
try:
# load the JSON, and then use it to build the Function and
# Tool Call
dec = JSONDecoder()
function_call_arr = []
# depending on the prompt format the Llama model may or may not
# prefix the output with the <|python_tag|> token
start_idx = len(self.bot_token) if model_output.startswith(
self.bot_token) else 0
while start_idx < len(model_output):
(obj, end_idx) = dec.raw_decode(model_output[start_idx:])
start_idx += end_idx + len('; ')
function_call_arr.append(obj)
tool_calls: list[ToolCall] = [
ToolCall(
type="function",
function=FunctionCall(
name=raw_function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(raw_function_call["arguments"] \
if "arguments" in raw_function_call \
else raw_function_call["parameters"],
ensure_ascii=False)))
for raw_function_call in function_call_arr
]
# get any content before the tool call
ret = ExtractedToolCallInformation(tools_called=True,
tool_calls=tool_calls,
content=None)
return ret
except Exception:
logger.exception("Error in extracting tool call from response.")
# return information to just treat the tool call as regular JSON
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
if not (current_text.startswith(self.bot_token)
or current_text.startswith('{')):
return DeltaMessage(content=delta_text)
# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/ function name at once.
flags = Allow.ALL if self.current_tool_name_sent \
else Allow.ALL & ~Allow.STR
try:
tool_call_arr = []
is_complete = []
try:
# depending on the prompt format the Llama model may or may not
# prefix the output with the <|python_tag|> token
start_idx = len(self.bot_token) if current_text.startswith(
self.bot_token) else 0
while start_idx < len(current_text):
(obj,
end_idx) = partial_json_loads(current_text[start_idx:],
flags)
is_complete.append(
is_complete_json(current_text[start_idx:start_idx +
end_idx]))
start_idx += end_idx + len('; ')
# depending on the prompt Llama can use
# either arguments or parameters
if "parameters" in obj:
assert "arguments" not in obj, \
"model generated both parameters and arguments"
obj["arguments"] = obj["parameters"]
tool_call_arr.append(obj)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
return None
# select as the current tool call the one we're on the state at
current_tool_call: dict = tool_call_arr[self.current_tool_id] \
if len(tool_call_arr) > 0 else {}
# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if len(tool_call_arr) == 0:
return None
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
elif (len(tool_call_arr) > 0
and len(tool_call_arr) > self.current_tool_id + 1):
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if self.current_tool_id >= 0:
cur_arguments = current_tool_call.get("arguments")
if cur_arguments:
cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
argument_diff = cur_args_json[sent:]
logger.debug("got arguments diff: %s", argument_diff)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff
else:
delta = None
else:
delta = None
# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("starting on new tool %d", self.current_tool_id)
return delta
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
elif not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=random_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
])
self.current_tool_name_sent = True
else:
delta = None
# now we know we're on the same tool call and we're streaming
# arguments
else:
cur_arguments = current_tool_call.get("arguments")
delta = None
if cur_arguments:
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
prev_arguments = self.prev_tool_call_arr[
self.current_tool_id].get("arguments")
argument_diff = None
if is_complete[self.current_tool_id]:
argument_diff = cur_args_json[sent:]
elif prev_arguments:
prev_args_json = json.dumps(prev_arguments,
ensure_ascii=False)
if cur_args_json != prev_args_json:
prefix = find_common_prefix(
prev_args_json, cur_args_json)
argument_diff = prefix[sent:]
if argument_diff is not None:
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff
self.prev_tool_call_arr = tool_call_arr
return delta
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction "
"error")
return None

View File

@@ -0,0 +1,369 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
from random import choices
from string import ascii_letters, digits
from typing import Union
import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from pydantic import Field
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
logger = init_logger(__name__)
ALPHANUMERIC = ascii_letters + digits
class MistralToolCall(ToolCall):
id: str = Field(
default_factory=lambda: MistralToolCall.generate_random_id())
@staticmethod
def generate_random_id():
# Mistral Tool Call Ids must be alphanumeric with a length of 9.
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return "".join(choices(ALPHANUMERIC, k=9))
@staticmethod
def is_valid_id(id: str) -> bool:
return id.isalnum() and len(id) == 9
def _is_fn_name_regex_support(model_tokenizer: AnyTokenizer) -> bool:
return isinstance(model_tokenizer, MistralTokenizer) \
and model_tokenizer.version >= 11
@ToolParserManager.register_module("mistral")
class MistralToolParser(ToolParser):
"""
Tool call parser for Mistral 7B Instruct v0.3, intended for use with
- [`mistral_common`](https://github.com/mistralai/mistral-common/)
- the examples/tool_chat_template_mistral.jinja template.
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
"""
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
if not isinstance(self.model_tokenizer, MistralTokenizer):
logger.info("Non-Mistral tokenizer detected when using a Mistral "
"model...")
# initialize properties used for state when parsing tool calls in
# streaming mode
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: list[str] = [
] # map what has been streamed for each tool so far to a list
self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.vocab.get(self.bot_token)
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
if _is_fn_name_regex_support(self.model_tokenizer):
self.fn_name_regex = re.compile(r'([a-zA-Z0-9_-]+)(\{.*?\})',
re.DOTALL)
else:
self.fn_name_regex = None
if self.bot_token_id is None:
raise RuntimeError(
"Mistral Tool Parser could not locate the tool call token in "
"the tokenizer!")
def adjust_request(
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
if not isinstance(
self.model_tokenizer, MistralTokenizer
) and request.tools and request.tool_choice != 'none':
# Do not skip special tokens when using chat template
# with Mistral parser as TOOL_CALL token is needed
# for tool detection.
# Note: we don't want skip_special_tokens=False
# with MistralTokenizer as it is incompatible
request.skip_special_tokens = False
return request
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response. Requires
find-and-replacing single quotes with double quotes for JSON parsing,
make sure your tool call arguments don't ever include quotes!
"""
# case -- if a tool call token is not present, return a text response
if self.bot_token not in model_output:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
# first remove the BOT token
tool_content = model_output.replace(self.bot_token, "").strip()
try:
# we first try to directly load the json as parsing very nested
# jsons is difficult
try:
if self.fn_name_regex:
matches = self.fn_name_regex.findall(tool_content)
function_call_arr = []
for match in matches:
fn_name = match[0]
args = match[1]
# fn_name is encoded outside serialized json dump
# only arguments are serialized
function_call_arr.append({
"name": fn_name,
"arguments": json.loads(args)
})
else:
function_call_arr = json.loads(tool_content)
except json.JSONDecodeError:
# use a regex to find the part corresponding to the tool call.
# NOTE: This use case should not happen if the model is trained
# correctly. It's a easy possible fix so it's included, but
# can be brittle for very complex / highly nested tool calls
raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
function_call_arr = json.loads(raw_tool_call)
# Tool Call
tool_calls: list[MistralToolCall] = [
MistralToolCall(
type="function",
function=FunctionCall(
name=raw_function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(raw_function_call["arguments"],
ensure_ascii=False)))
for raw_function_call in function_call_arr
]
# get any content before the tool call
content = model_output.split(self.bot_token)[0]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if len(content) > 0 else None)
except Exception:
logger.exception("Error in extracting tool call from response.")
# return information to just treat the tool call as regular JSON
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=tool_content)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
# if the tool call token is not in the tokens generated so far, append
# output to contents since it's not a tool
if self.bot_token not in current_text:
return DeltaMessage(content=delta_text)
# if the tool call token ID IS in the tokens generated so far, that
# means we're parsing as tool calls now
# handle if we detected the BOT token which means the start of tool
# calling
if (self.bot_token_id in delta_token_ids
and len(delta_token_ids) == 1):
# if it's the only token, return None, so we don't send a chat
# completion any don't send a control token
return None
# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/ function name at once.
flags = Allow.ALL if self.current_tool_name_sent \
else Allow.ALL & ~Allow.STR
try:
# replace BOT token with empty string, and convert single quotes
# to double to allow parsing as JSON since mistral uses single
# quotes instead of double for tool calls
parsable_arr = current_text.split(self.bot_token)[-1]
# tool calls are generated in an array, so do partial JSON
# parsing on the entire array
try:
tool_call_arr: list[dict] = partial_json_parser.loads(
parsable_arr, flags)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
return None
# select as the current tool call the one we're on the state at
current_tool_call: dict = tool_call_arr[self.current_tool_id] \
if len(tool_call_arr) > 0 else {}
# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if len(tool_call_arr) == 0:
return None
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
elif (len(tool_call_arr) > 0
and len(tool_call_arr) > self.current_tool_id + 1):
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if self.current_tool_id >= 0:
diff: Union[str, None] = current_tool_call.get("arguments")
if diff:
diff = json.dumps(diff, ensure_ascii=False).replace(
self.streamed_args_for_tool[self.current_tool_id],
"")
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=diff).model_dump(
exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += diff
else:
delta = None
else:
delta = None
# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("starting on new tool %d", self.current_tool_id)
return delta
# case: update an existing tool - this is handled below
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
if not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=MistralToolCall.generate_random_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
])
self.current_tool_name_sent = True
else:
delta = None
# now we know we're on the same tool call and we're streaming
# arguments
else:
prev_arguments = self.prev_tool_call_arr[
self.current_tool_id].get("arguments")
cur_arguments = current_tool_call.get("arguments")
new_text = delta_text.replace("\'", "\"")
if ('"}' in new_text):
new_text = new_text[:new_text.rindex('"}')]
if not cur_arguments and not prev_arguments:
delta = None
elif not cur_arguments and prev_arguments:
logger.error(
"INVARIANT - impossible to have arguments reset "
"mid-arguments")
delta = None
elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments,
ensure_ascii=False)[:-2]
logger.debug("finding %s in %s", new_text,
cur_arguments_json)
if (new_text not in cur_arguments_json):
return None
arguments_delta = cur_arguments_json[:cur_arguments_json.
rindex(new_text) +
len(new_text)]
logger.debug("First tokens in arguments received: %s",
arguments_delta)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=arguments_delta).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += arguments_delta
elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
prev_args_json = json.dumps(prev_arguments,
ensure_ascii=False)
logger.debug("Searching for diff between \n%s\n%s",
cur_args_json, prev_args_json)
argument_diff = extract_intermediate_diff(
cur_args_json, prev_args_json)
logger.debug("got arguments diff: %s", argument_diff)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).model_dump(
exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff
else:
# try parsing it with regular JSON - if it works we're
# at the end, and we need to send the difference between
# tokens streamed so far and the valid JSON
delta = None
# check to see if the name is defined and has been sent. if so,
# stream the name - otherwise keep waiting
# finish by setting old and returning None as base case
self.prev_tool_call_arr = tool_call_arr
return delta
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction "
"error")
return None

View File

@@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
from typing import Any, Optional
import regex as re
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
logger = init_logger(__name__)
@ToolParserManager.register_module("phi4_mini_json")
class Phi4MiniJsonToolParser(ToolParser):
"""
Tool call parser for phi-4-mini models intended for use with the
examples/tool_chat_template_llama.jinja template.
Used when --enable-auto-tool-choice --tool-call-parser phi4_mini_json
are all set
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None:
super().__init__(tokenizer)
# initialize properties used for state when parsing tool calls in
# streaming mode
self.prev_tool_call_arr: list[dict[str, Any]] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: list[str] = [
] # map what has been streamed for each tool so far to a list
self.bot_token: str = "functools"
def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
"""
logger.debug("Model output: %s", model_output)
pattern = r'functools\[(.*?)\]'
matches = re.search(pattern, model_output, re.DOTALL)
if not matches:
logger.debug("No function calls found")
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
try:
function_call_arr: list[dict[str, Any]] = []
try:
json_content = '[' + matches.group(1) + ']'
function_call_arr = json.loads(json_content)
logger.debug("Successfully extracted %d function calls",
len(function_call_arr))
except json.JSONDecodeError as e:
logger.error(
"Failed to parse function calls from model output. "
"Error: %s", str(e))
tool_calls: list[ToolCall] = [
ToolCall(
id=random_tool_call_id(),
type="function",
function=FunctionCall(
name=raw_function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(
raw_function_call["arguments"]
if "arguments" in raw_function_call else
raw_function_call["parameters"],
ensure_ascii=False),
)) for raw_function_call in function_call_arr
]
# get any content before the tool call
ret = ExtractedToolCallInformation(tools_called=True,
tool_calls=tool_calls,
content=None)
return ret
except Exception:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Optional[DeltaMessage]:
return None

View File

@@ -0,0 +1,308 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
from collections.abc import Sequence
from typing import Any, Union
import regex as re
from transformers import PreTrainedTokenizerBase
import vllm.envs as envs
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
logger = init_logger(__name__)
class _UnexpectedAstError(Exception):
pass
@ToolParserManager.register_module("pythonic")
class PythonicToolParser(ToolParser):
"""
Tool call parser for models that produce tool calls in a pythonic style,
such as Llama 3.2 and Llama 4 models.
Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set
"""
# TODO(mdepinet): Possible future improvements:
# 1. Support text + tools separated by either <|python_tag|> or \n\n
# 2. Support tools outside of a list (or separated by a semicolon).
# This depends on item 1 for consistent streaming.
# Neither of these are necessary for e.g. ToolACE, but both would help make
# Llama3.2 models more reliable.
TOOL_CALL_REGEX = re.compile(
r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
re.DOTALL)
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
# Rename for readability. This is NOT a tool id.
@property
def current_tool_index(self) -> int:
return self.current_tool_id
@current_tool_index.setter
def current_tool_index(self, value: int) -> None:
self.current_tool_id = value
def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
"""
is_tool_call_pattern = False
try:
is_tool_call_pattern = self.TOOL_CALL_REGEX.match(
model_output,
timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS) is not None
except TimeoutError:
logger.warning(
"Regex timeout occurred when matching tool call pattern.")
logger.debug("Regex timeout occurred when matching user input: %s",
model_output)
if not is_tool_call_pattern:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
try:
module = ast.parse(model_output)
parsed = getattr(module.body[0], "value", None)
if isinstance(parsed, ast.List) and all(
isinstance(e, ast.Call) for e in parsed.elts):
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=[
_handle_single_tool(e) # type: ignore
for e in parsed.elts
],
content=None)
else:
raise _UnexpectedAstError(
"Tool output must be a list of function calls")
except Exception:
logger.exception("Error in extracting tool call from response.")
# Treat as regular text
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
if not current_text.startswith("["):
return DeltaMessage(content=delta_text)
try:
valid_and_added_text = _make_valid_python(current_text)
if valid_and_added_text is None:
return None
valid_text, added_text = valid_and_added_text
module = ast.parse(valid_text)
parsed = getattr(module.body[0], "value", None)
if not isinstance(parsed, ast.List) or not all(
isinstance(e, ast.Call) for e in parsed.elts):
raise _UnexpectedAstError(
"Tool output must be a list of function calls")
tool_calls = [
_handle_single_tool(e) # type: ignore
for e in parsed.elts
]
tool_deltas = []
for index, new_call in enumerate(tool_calls):
if index < self.current_tool_index:
continue
self.current_tool_index = index
if len(self.streamed_args_for_tool) == index:
self.streamed_args_for_tool.append("")
new_call_complete = index < len(
tool_calls) - 1 or ")]" not in added_text
if new_call_complete:
self.current_tool_index += 1
withheld_suffix = (added_text[:-2]
if not new_call_complete else "")
if not new_call_complete and added_text[-2] == ")":
# Function call is incomplete. Withhold the closing bracket.
withheld_suffix = withheld_suffix + "}"
# Strings get single quotes in the model-produced string.
# JSON requires double quotes.
withheld_suffix = withheld_suffix.replace("'", '"')
delta = _compute_tool_delta(self.streamed_args_for_tool[index],
new_call, index, withheld_suffix)
if delta is not None:
tool_deltas.append(delta)
if (delta.function is not None
and delta.function.arguments is not None):
self.streamed_args_for_tool[
index] += delta.function.arguments
# HACK: serving_chat.py inspects the internal state of tool parsers
# when determining it's final streaming delta, automatically
# adding autocompleted JSON.
# These two lines avoid that nonsense while ensuring finish_reason
# is set to tool_calls when at least one tool is called.
if tool_deltas and not self.prev_tool_call_arr:
self.prev_tool_call_arr = [{"arguments": {}}]
if tool_deltas:
return DeltaMessage(tool_calls=tool_deltas)
elif not added_text and self.current_tool_id > 0:
# Return an empty DeltaMessage once the tool calls are all done
# so that finish_reason gets set.
return DeltaMessage(content='')
else:
return None
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction "
"error")
return None
def _get_parameter_value(val: ast.expr) -> Any:
if isinstance(val, ast.Constant):
return val.value
elif isinstance(val, ast.Dict):
if not all(isinstance(k, ast.Constant) for k in val.keys):
raise _UnexpectedAstError(
"Dict tool call arguments must have literal keys")
return {
k.value: _get_parameter_value(v) # type: ignore
for k, v in zip(val.keys, val.values)
}
elif isinstance(val, ast.List):
return [_get_parameter_value(v) for v in val.elts]
else:
raise _UnexpectedAstError("Tool call arguments must be literals")
def _handle_single_tool(call: ast.Call) -> ToolCall:
if not isinstance(call.func, ast.Name):
raise _UnexpectedAstError("Invalid tool call name")
function_name = call.func.id
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = _get_parameter_value(keyword.value)
return ToolCall(
type="function",
function=FunctionCall(name=function_name,
arguments=json.dumps(arguments,
ensure_ascii=False)),
)
def _make_valid_python(text: str) -> Union[tuple[str, str], None]:
bracket_stack = []
for index, char in enumerate(text):
if char in {"[", "(", "{"}:
bracket_stack.append(char)
elif char == "]":
if not bracket_stack or bracket_stack.pop() != "[":
raise _UnexpectedAstError("Mismatched square brackets")
elif char == ")":
if not bracket_stack or bracket_stack.pop() != "(":
raise _UnexpectedAstError("Mismatched parentheses")
elif char == "}":
if not bracket_stack or bracket_stack.pop() != "{":
raise _UnexpectedAstError("Mismatched curly braces")
elif char in {"'", '"'}:
if bracket_stack and bracket_stack[-1] == char:
if index > 0 and text[index - 1] == "\\":
# Treat an escaped quote as a regular character
pass
else:
bracket_stack.pop()
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
# Double quote within a single quote string or vice versa.
pass
else:
bracket_stack.append(char)
text = text.rstrip()
if text.endswith("=") or text.endswith(":"):
# Since we have no type information for this property/parameter value,
# we can't fill in a valid value.
return None
if bracket_stack and bracket_stack[-1] == "{":
trailing_dict_text = text[:text.rfind("{")]
num_keys = trailing_dict_text.count(":")
num_values = trailing_dict_text.count(",")
if num_keys <= num_values:
return None # Incomplete property name within parameter value
if bracket_stack and bracket_stack[-1] == "(":
trailing_params_text = text[:text.rfind("(")]
num_full_param_names = trailing_params_text.count("=")
num_full_param_values = trailing_params_text.count(",")
if num_full_param_names <= num_full_param_values:
return None # Incomplete parameter name
if text.endswith(","):
text = text[:-1]
if bracket_stack and bracket_stack[-1] == "[" and not text.endswith(
"[") and not text.endswith(")"):
return None # Incomplete function name
added_text = ""
for char in reversed(bracket_stack):
if char == "[":
added_text += "]"
elif char == "(":
added_text += ")"
elif char == "{":
added_text += "}"
elif char == "'":
added_text += "'"
elif char == '"':
added_text += '"'
return text + added_text, added_text
def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,
index: int,
withheld_suffix: str) -> Union[DeltaToolCall, None]:
new_call_args = new_call.function.arguments
if withheld_suffix:
assert new_call_args.endswith(withheld_suffix)
new_call_args = new_call_args[:-len(withheld_suffix)]
if not previously_sent_args:
return DeltaToolCall(id=new_call.id,
type="function",
index=index,
function=DeltaFunctionCall(
name=new_call.function.name,
arguments=new_call_args,
))
arg_diff = new_call_args[len(previously_sent_args):]
return DeltaToolCall(
id=None, index=index, function=DeltaFunctionCall(
arguments=arg_diff)) if arg_diff else None

View File

@@ -0,0 +1,124 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from json import JSONDecodeError, JSONDecoder
from typing import Any
import partial_json_parser
from partial_json_parser.core.options import Allow
def find_common_prefix(s1: str, s2: str) -> str:
"""
Finds a common prefix that is shared between two strings, if there is one.
Order of arguments is NOT important.
This function is provided as a UTILITY for extracting information from JSON
generated by partial_json_parser, to help in ensuring that the right tokens
are returned in streaming, so that close-quotes, close-brackets and
close-braces are not returned prematurely.
e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
'{"fruit": "ap'
"""
prefix = ''
min_length = min(len(s1), len(s2))
for i in range(0, min_length):
if s1[i] == s2[i]:
prefix += s1[i]
else:
break
return prefix
def find_common_suffix(s1: str, s2: str) -> str:
"""
Finds a common suffix shared between two strings, if there is one. Order of
arguments is NOT important.
Stops when the suffix ends OR it hits an alphanumeric character
e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
"""
suffix = ''
min_length = min(len(s1), len(s2))
for i in range(1, min_length + 1):
if s1[-i] == s2[-i] and not s1[-i].isalnum():
suffix = s1[-i] + suffix
else:
break
return suffix
def extract_intermediate_diff(curr: str, old: str) -> str:
"""
Given two strings, extract the difference in the middle between two strings
that are known to have a common prefix and/or suffix.
This function is provided as a UTILITY for extracting information from JSON
generated by partial_json_parser, to help in ensuring that the right tokens
are returned in streaming, so that close-quotes, close-brackets and
close-braces are not returned prematurely. The order of arguments IS
important - the new version of the partially-parsed JSON must be the first
argument, and the secnod argument must be from the previous generation.
What it returns, is tokens that should be streamed to the client.
e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
-> 'ple'
"""
suffix = find_common_suffix(curr, old)
old = old[::-1].replace(suffix[::-1], '', 1)[::-1]
prefix = find_common_prefix(curr, old)
diff = curr
if len(suffix):
diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1]
if len(prefix):
# replace the prefix only once in case it's mirrored
diff = diff.replace(prefix, '', 1)
return diff
def find_all_indices(string: str, substring: str) -> list[int]:
"""
Find all (starting) indices of a substring in a given string. Useful for
tool call extraction
"""
indices = []
index = -1
while True:
index = string.find(substring, index + 1)
if index == -1:
break
indices.append(index)
return indices
# partial_json_parser doesn't support extra data and
# JSONDecoder.raw_decode doesn't support partial JSON
def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
try:
return (partial_json_parser.loads(input_str, flags), len(input_str))
except JSONDecodeError as e:
if "Extra data" in e.msg:
dec = JSONDecoder()
return dec.raw_decode(input_str)
raise
def is_complete_json(input_str: str) -> bool:
try:
json.loads(input_str)
return True
except JSONDecodeError:
return False
def consume_space(i: int, s: str) -> int:
while i < len(s) and s[i].isspace():
i += 1
return i