refactor(tool call): Fix BaseFormatDetector tool_index issue and refactor parse_streaming_increment (#6715)

This commit is contained in:
Chang Su
2025-05-29 00:08:45 -07:00
committed by GitHub
parent f4d4f93928
commit c673727e0e
7 changed files with 366 additions and 86 deletions

View File

@@ -36,6 +36,7 @@ class BaseFormatDetector(ABC):
) # map what has been streamed for each tool so far to a list
self.bot_token = ""
self.eot_token = ""
self.tool_call_separator = ", "
def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
tool_indices = {
@@ -50,7 +51,7 @@ class BaseFormatDetector(ABC):
if name and name in tool_indices:
results.append(
ToolCallItem(
tool_index=tool_indices[name],
tool_index=-1, # Caller should update this based on the actual tools array called
name=name,
parameters=json.dumps(
act.get("parameters") or act.get("arguments", {}),
@@ -106,7 +107,17 @@ class BaseFormatDetector(ABC):
# Append new text to buffer
self._buffer += new_text
current_text = self._buffer
if not (self.bot_token in current_text or current_text.startswith("{")):
# current_text contains a tool call if it is the start of a new tool call sequence,
# or if it begins with the tool call separator following a previous tool call
if not (
self.bot_token in current_text
or current_text.startswith("{")
or (
self.current_tool_id > 0
and current_text.startswith(self.tool_call_separator + "{")
)
):
# Only clear buffer if we're sure no tool call is starting
if not self._ends_with_partial_token(self._buffer, self.bot_token):
normal_text = self._buffer
@@ -127,91 +138,73 @@ class BaseFormatDetector(ABC):
}
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
try:
tool_call_arr = []
is_complete = []
try:
start_idx = (
len(self.bot_token)
if current_text.startswith(self.bot_token)
else 0
if current_text.startswith(self.bot_token):
start_idx = len(self.bot_token)
elif self.current_tool_id > 0 and current_text.startswith(
self.tool_call_separator
):
start_idx = len(self.tool_call_separator)
else:
start_idx = 0
if start_idx >= len(current_text):
return StreamingParseResult()
(obj, end_idx) = _partial_json_loads(current_text[start_idx:], flags)
is_current_complete = _is_complete_json(
current_text[start_idx : start_idx + end_idx]
)
while start_idx < len(current_text):
(obj, end_idx) = _partial_json_loads(
current_text[start_idx:], flags
)
is_complete.append(
_is_complete_json(current_text[start_idx : start_idx + end_idx])
)
start_idx += end_idx + len("; ")
# Validate tool name if present
if "name" in obj and obj["name"] not in self._tool_indices:
# Invalid tool name - reset state
self._buffer = ""
self.current_tool_id = -1
self.current_tool_name_sent = False
if self.streamed_args_for_tool:
self.streamed_args_for_tool.pop()
return StreamingParseResult()
# Validate tool name if present
if "name" in obj and obj["name"] not in self._tool_indices:
# Invalid tool name - reset state
self._buffer = ""
self.current_tool_id = -1
self.current_tool_name_sent = False
if self.streamed_args_for_tool:
self.streamed_args_for_tool.pop()
return StreamingParseResult()
# Handle parameters/arguments consistency
if "parameters" in obj:
assert (
"arguments" not in obj
), "model generated both parameters and arguments"
obj["arguments"] = obj["parameters"]
tool_call_arr.append(obj)
# Handle parameters/arguments consistency
# NOTE: we assume here that obj is always a partial parse of a single tool call
if "parameters" in obj:
assert (
"arguments" not in obj
), "model generated both parameters and arguments"
obj["arguments"] = obj["parameters"]
current_tool_call = obj
except MalformedJSON:
return StreamingParseResult()
if len(tool_call_arr) == 0:
if not current_tool_call:
return StreamingParseResult()
current_tool_call: Dict = (
tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
)
# Handle new tool in array
if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:
if self.current_tool_id >= 0:
cur_arguments = current_tool_call.get("arguments")
if cur_arguments:
cur_args_json = json.dumps(cur_arguments)
sent = len(self.streamed_args_for_tool[self.current_tool_id])
argument_diff = cur_args_json[sent:]
res = StreamingParseResult(
calls=[
ToolCallItem(
tool_index=self.current_tool_id,
name="",
parameters=argument_diff,
)
],
)
self.streamed_args_for_tool[
self.current_tool_id
] += argument_diff
else:
res = StreamingParseResult()
else:
res = StreamingParseResult()
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
return res
# Handle tool name
elif not self.current_tool_name_sent:
# Case 1: Handle tool name streaming
# This happens when we encounter a tool but haven't sent its name yet
if not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name and function_name in self._tool_indices:
# If this is a new tool (current_tool_id was -1), initialize it
if self.current_tool_id == -1:
self.current_tool_id = 0
self.streamed_args_for_tool.append("")
# If this is a subsequent tool, ensure streamed_args_for_tool is large enough
elif self.current_tool_id >= len(self.streamed_args_for_tool):
while len(self.streamed_args_for_tool) <= self.current_tool_id:
self.streamed_args_for_tool.append("")
# Send the tool name with empty parameters
res = StreamingParseResult(
calls=[
ToolCallItem(
tool_index=self._tool_indices[function_name],
tool_index=self.current_tool_id,
name=function_name,
parameters="",
)
@@ -221,47 +214,75 @@ class BaseFormatDetector(ABC):
else:
res = StreamingParseResult()
# Handle streaming arguments
# Case 2: Handle streaming arguments
# This happens when we've already sent the tool name and now need to stream arguments incrementally
else:
cur_arguments = current_tool_call.get("arguments")
res = StreamingParseResult()
if cur_arguments:
# Calculate how much of the arguments we've already streamed
sent = len(self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments)
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments"
)
prev_arguments = None
if self.current_tool_id < len(self.prev_tool_call_arr):
prev_arguments = self.prev_tool_call_arr[
self.current_tool_id
].get("arguments")
argument_diff = None
if is_complete[self.current_tool_id]:
# If the current tool's JSON is complete, send all remaining arguments
if is_current_complete:
argument_diff = cur_args_json[sent:]
self._buffer = ""
self.prev_tool_call_arr[self.current_tool_id].clear()
completing_tool_id = (
self.current_tool_id
) # Save the ID of the tool that's completing
# Only remove the processed portion, keep unprocessed content
self._buffer = current_text[start_idx + end_idx :]
if self.current_tool_id < len(self.prev_tool_call_arr):
self.prev_tool_call_arr[self.current_tool_id].clear()
self.current_tool_name_sent = False
self.streamed_args_for_tool[self.current_tool_id] = ""
self.current_tool_id += 1
# If the tool is still being parsed, send incremental changes
elif prev_arguments:
prev_args_json = json.dumps(prev_arguments)
if cur_args_json != prev_args_json:
prefix = _find_common_prefix(prev_args_json, cur_args_json)
argument_diff = prefix[sent:]
# Send the argument diff if there's something new
if argument_diff is not None:
# Use the correct tool_index: completing_tool_id for completed tools, current_tool_id for ongoing
tool_index_to_use = (
completing_tool_id
if is_current_complete
else self.current_tool_id
)
res = StreamingParseResult(
calls=[
ToolCallItem(
tool_index=self.current_tool_id,
tool_index=tool_index_to_use,
parameters=argument_diff,
)
],
)
if not is_complete[self.current_tool_id]:
if not is_current_complete:
self.streamed_args_for_tool[
self.current_tool_id
] += argument_diff
self.prev_tool_call_arr = tool_call_arr
# Update prev_tool_call_arr with current state
if self.current_tool_id >= 0:
# Ensure prev_tool_call_arr is large enough
while len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
return res
except Exception as e:

View File

@@ -24,6 +24,11 @@ class Llama32Detector(BaseFormatDetector):
def __init__(self):
super().__init__()
self.bot_token = "<|python_tag|>"
# NOTE: technically Llama 3.2 does not handle parallel tool calls well;
# it needs specific prompt engineering to support them.
# Here we use ';' as the separator, which might cause compatibility issues
# if users define a different separator in their prompt
self.tool_call_separator = ";"
def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a Llama 3.2 format tool call."""
@@ -42,7 +47,11 @@ class Llama32Detector(BaseFormatDetector):
normal_text, action_text = "", text
# Split by semicolon and process each part
json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
json_parts = [
part.strip()
for part in action_text.split(self.tool_call_separator)
if part.strip()
]
all_actions = []
for part in json_parts:
try:
@@ -70,5 +79,5 @@ class Llama32Detector(BaseFormatDetector):
return EBNFComposer.build_ebnf(
tools,
function_format="json",
tool_call_separator=",",
tool_call_separator=self.tool_call_separator,
)

View File

@@ -30,6 +30,7 @@ class MistralDetector(BaseFormatDetector):
self.bot_token = "[TOOL_CALLS] ["
self.eot_token = "]"
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
self.tool_call_separator = ", "
def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a Mistral format tool call."""
@@ -126,5 +127,5 @@ class MistralDetector(BaseFormatDetector):
sequence_start_token=self.bot_token,
sequence_end_token=self.eot_token,
function_format="json",
tool_call_separator=", ",
tool_call_separator=self.tool_call_separator,
)

View File

@@ -29,6 +29,7 @@ class Qwen25Detector(BaseFormatDetector):
super().__init__()
self.bot_token = "<tool_call>\n"
self.eot_token = "\n</tool_call>"
self.tool_call_separator = "\n"
self._normal_text_buffer = "" # Buffer for handling partial end tokens
def has_tool_call(self, text: str) -> bool:
@@ -104,7 +105,6 @@ class Qwen25Detector(BaseFormatDetector):
return result
def structure_info(self) -> _GetInfoFunc:
# TODO: Update the begin and end tokens with '\n' if necessary
return lambda name: StructureInfo(
begin='<tool_call>\n{"name":"' + name + '", "arguments":',
end="}\n</tool_call>",

View File

@@ -18,6 +18,23 @@ def _find_common_prefix(s1: str, s2: str) -> str:
def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
"""
Parse incomplete or partial JSON strings commonly encountered during streaming.
Args:
input_str (str): The potentially incomplete JSON string to parse.
flags (Allow): Bitwise flags controlling what types of partial data are allowed.
Common flags include:
- Allow.STR: Allow partial strings (e.g., '"hello wo' -> 'hello wo')
- Allow.OBJ: Allow partial objects (e.g., '{"key":' -> {'key': None})
- Allow.ARR: Allow partial arrays (e.g., '[1, 2,' -> [1, 2])
- Allow.ALL: Allow all types of partial data
Returns:
Tuple[Any, int]: A tuple containing:
- parsed_object: The Python object parsed from the JSON
- consumed_length: Number of characters consumed from input_str
"""
try:
return (partial_json_parser.loads(input_str, flags), len(input_str))
except JSONDecodeError as e:

View File

@@ -1327,7 +1327,6 @@ def v1_chat_generate_response(
tool_calls = [
ToolCall(
id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
index=call_info.tool_index,
function=FunctionResponse(
name=call_info.name, arguments=call_info.parameters
),