refactor(tool call): Fix BaseFormatDetector tool_index issue and refactor parse_streaming_increment (#6715)
This commit is contained in:
@@ -36,6 +36,7 @@ class BaseFormatDetector(ABC):
|
|||||||
) # map what has been streamed for each tool so far to a list
|
) # map what has been streamed for each tool so far to a list
|
||||||
self.bot_token = ""
|
self.bot_token = ""
|
||||||
self.eot_token = ""
|
self.eot_token = ""
|
||||||
|
self.tool_call_separator = ", "
|
||||||
|
|
||||||
def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
|
def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
|
||||||
tool_indices = {
|
tool_indices = {
|
||||||
@@ -50,7 +51,7 @@ class BaseFormatDetector(ABC):
|
|||||||
if name and name in tool_indices:
|
if name and name in tool_indices:
|
||||||
results.append(
|
results.append(
|
||||||
ToolCallItem(
|
ToolCallItem(
|
||||||
tool_index=tool_indices[name],
|
tool_index=-1, # Caller should update this based on the actual tools array called
|
||||||
name=name,
|
name=name,
|
||||||
parameters=json.dumps(
|
parameters=json.dumps(
|
||||||
act.get("parameters") or act.get("arguments", {}),
|
act.get("parameters") or act.get("arguments", {}),
|
||||||
@@ -106,7 +107,17 @@ class BaseFormatDetector(ABC):
|
|||||||
# Append new text to buffer
|
# Append new text to buffer
|
||||||
self._buffer += new_text
|
self._buffer += new_text
|
||||||
current_text = self._buffer
|
current_text = self._buffer
|
||||||
if not (self.bot_token in current_text or current_text.startswith("{")):
|
|
||||||
|
# current_text contains a tool call if it is the start of a new tool call sequence
|
||||||
|
# or it is the start of a new tool call after a tool call separator, when there is a previous tool call
|
||||||
|
if not (
|
||||||
|
self.bot_token in current_text
|
||||||
|
or current_text.startswith("{")
|
||||||
|
or (
|
||||||
|
self.current_tool_id > 0
|
||||||
|
and current_text.startswith(self.tool_call_separator + "{")
|
||||||
|
)
|
||||||
|
):
|
||||||
# Only clear buffer if we're sure no tool call is starting
|
# Only clear buffer if we're sure no tool call is starting
|
||||||
if not self._ends_with_partial_token(self._buffer, self.bot_token):
|
if not self._ends_with_partial_token(self._buffer, self.bot_token):
|
||||||
normal_text = self._buffer
|
normal_text = self._buffer
|
||||||
@@ -127,91 +138,73 @@ class BaseFormatDetector(ABC):
|
|||||||
}
|
}
|
||||||
|
|
||||||
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
|
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tool_call_arr = []
|
|
||||||
is_complete = []
|
|
||||||
try:
|
try:
|
||||||
start_idx = (
|
if current_text.startswith(self.bot_token):
|
||||||
len(self.bot_token)
|
start_idx = len(self.bot_token)
|
||||||
if current_text.startswith(self.bot_token)
|
elif self.current_tool_id > 0 and current_text.startswith(
|
||||||
else 0
|
self.tool_call_separator
|
||||||
|
):
|
||||||
|
start_idx = len(self.tool_call_separator)
|
||||||
|
else:
|
||||||
|
start_idx = 0
|
||||||
|
|
||||||
|
if start_idx >= len(current_text):
|
||||||
|
return StreamingParseResult()
|
||||||
|
|
||||||
|
(obj, end_idx) = _partial_json_loads(current_text[start_idx:], flags)
|
||||||
|
|
||||||
|
is_current_complete = _is_complete_json(
|
||||||
|
current_text[start_idx : start_idx + end_idx]
|
||||||
)
|
)
|
||||||
while start_idx < len(current_text):
|
|
||||||
(obj, end_idx) = _partial_json_loads(
|
|
||||||
current_text[start_idx:], flags
|
|
||||||
)
|
|
||||||
is_complete.append(
|
|
||||||
_is_complete_json(current_text[start_idx : start_idx + end_idx])
|
|
||||||
)
|
|
||||||
start_idx += end_idx + len("; ")
|
|
||||||
|
|
||||||
# Validate tool name if present
|
# Validate tool name if present
|
||||||
if "name" in obj and obj["name"] not in self._tool_indices:
|
if "name" in obj and obj["name"] not in self._tool_indices:
|
||||||
# Invalid tool name - reset state
|
# Invalid tool name - reset state
|
||||||
self._buffer = ""
|
self._buffer = ""
|
||||||
self.current_tool_id = -1
|
self.current_tool_id = -1
|
||||||
self.current_tool_name_sent = False
|
self.current_tool_name_sent = False
|
||||||
if self.streamed_args_for_tool:
|
if self.streamed_args_for_tool:
|
||||||
self.streamed_args_for_tool.pop()
|
self.streamed_args_for_tool.pop()
|
||||||
return StreamingParseResult()
|
return StreamingParseResult()
|
||||||
|
|
||||||
# Handle parameters/arguments consistency
|
# Handle parameters/arguments consistency
|
||||||
if "parameters" in obj:
|
# NOTE: we assume here that obj is always a (possibly partial) single tool call
|
||||||
assert (
|
if "parameters" in obj:
|
||||||
"arguments" not in obj
|
assert (
|
||||||
), "model generated both parameters and arguments"
|
"arguments" not in obj
|
||||||
obj["arguments"] = obj["parameters"]
|
), "model generated both parameters and arguments"
|
||||||
tool_call_arr.append(obj)
|
obj["arguments"] = obj["parameters"]
|
||||||
|
|
||||||
|
current_tool_call = obj
|
||||||
|
|
||||||
except MalformedJSON:
|
except MalformedJSON:
|
||||||
return StreamingParseResult()
|
return StreamingParseResult()
|
||||||
|
|
||||||
if len(tool_call_arr) == 0:
|
if not current_tool_call:
|
||||||
return StreamingParseResult()
|
return StreamingParseResult()
|
||||||
|
|
||||||
current_tool_call: Dict = (
|
# Case 1: Handle tool name streaming
|
||||||
tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
|
# This happens when we encounter a tool but haven't sent its name yet
|
||||||
)
|
if not self.current_tool_name_sent:
|
||||||
|
|
||||||
# Handle new tool in array
|
|
||||||
if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:
|
|
||||||
if self.current_tool_id >= 0:
|
|
||||||
cur_arguments = current_tool_call.get("arguments")
|
|
||||||
if cur_arguments:
|
|
||||||
cur_args_json = json.dumps(cur_arguments)
|
|
||||||
sent = len(self.streamed_args_for_tool[self.current_tool_id])
|
|
||||||
argument_diff = cur_args_json[sent:]
|
|
||||||
|
|
||||||
res = StreamingParseResult(
|
|
||||||
calls=[
|
|
||||||
ToolCallItem(
|
|
||||||
tool_index=self.current_tool_id,
|
|
||||||
name="",
|
|
||||||
parameters=argument_diff,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
self.streamed_args_for_tool[
|
|
||||||
self.current_tool_id
|
|
||||||
] += argument_diff
|
|
||||||
else:
|
|
||||||
res = StreamingParseResult()
|
|
||||||
else:
|
|
||||||
res = StreamingParseResult()
|
|
||||||
|
|
||||||
self.current_tool_id = len(tool_call_arr) - 1
|
|
||||||
self.current_tool_name_sent = False
|
|
||||||
self.streamed_args_for_tool.append("")
|
|
||||||
return res
|
|
||||||
|
|
||||||
# Handle tool name
|
|
||||||
elif not self.current_tool_name_sent:
|
|
||||||
function_name = current_tool_call.get("name")
|
function_name = current_tool_call.get("name")
|
||||||
|
|
||||||
if function_name and function_name in self._tool_indices:
|
if function_name and function_name in self._tool_indices:
|
||||||
|
# If this is a new tool (current_tool_id was -1), initialize it
|
||||||
|
if self.current_tool_id == -1:
|
||||||
|
self.current_tool_id = 0
|
||||||
|
self.streamed_args_for_tool.append("")
|
||||||
|
# If this is a subsequent tool, ensure streamed_args_for_tool is large enough
|
||||||
|
elif self.current_tool_id >= len(self.streamed_args_for_tool):
|
||||||
|
while len(self.streamed_args_for_tool) <= self.current_tool_id:
|
||||||
|
self.streamed_args_for_tool.append("")
|
||||||
|
|
||||||
|
# Send the tool name with empty parameters
|
||||||
res = StreamingParseResult(
|
res = StreamingParseResult(
|
||||||
calls=[
|
calls=[
|
||||||
ToolCallItem(
|
ToolCallItem(
|
||||||
tool_index=self._tool_indices[function_name],
|
tool_index=self.current_tool_id,
|
||||||
name=function_name,
|
name=function_name,
|
||||||
parameters="",
|
parameters="",
|
||||||
)
|
)
|
||||||
@@ -221,47 +214,75 @@ class BaseFormatDetector(ABC):
|
|||||||
else:
|
else:
|
||||||
res = StreamingParseResult()
|
res = StreamingParseResult()
|
||||||
|
|
||||||
# Handle streaming arguments
|
# Case 2: Handle streaming arguments
|
||||||
|
# This happens when we've already sent the tool name and now need to stream arguments incrementally
|
||||||
else:
|
else:
|
||||||
cur_arguments = current_tool_call.get("arguments")
|
cur_arguments = current_tool_call.get("arguments")
|
||||||
res = StreamingParseResult()
|
res = StreamingParseResult()
|
||||||
|
|
||||||
if cur_arguments:
|
if cur_arguments:
|
||||||
|
# Calculate how much of the arguments we've already streamed
|
||||||
sent = len(self.streamed_args_for_tool[self.current_tool_id])
|
sent = len(self.streamed_args_for_tool[self.current_tool_id])
|
||||||
cur_args_json = json.dumps(cur_arguments)
|
cur_args_json = json.dumps(cur_arguments)
|
||||||
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
|
prev_arguments = None
|
||||||
"arguments"
|
if self.current_tool_id < len(self.prev_tool_call_arr):
|
||||||
)
|
prev_arguments = self.prev_tool_call_arr[
|
||||||
|
self.current_tool_id
|
||||||
|
].get("arguments")
|
||||||
|
|
||||||
argument_diff = None
|
argument_diff = None
|
||||||
if is_complete[self.current_tool_id]:
|
|
||||||
|
# If the current tool's JSON is complete, send all remaining arguments
|
||||||
|
if is_current_complete:
|
||||||
argument_diff = cur_args_json[sent:]
|
argument_diff = cur_args_json[sent:]
|
||||||
self._buffer = ""
|
completing_tool_id = (
|
||||||
self.prev_tool_call_arr[self.current_tool_id].clear()
|
self.current_tool_id
|
||||||
|
) # Save the ID of the tool that's completing
|
||||||
|
|
||||||
|
# Only remove the processed portion, keep unprocessed content
|
||||||
|
self._buffer = current_text[start_idx + end_idx :]
|
||||||
|
|
||||||
|
if self.current_tool_id < len(self.prev_tool_call_arr):
|
||||||
|
self.prev_tool_call_arr[self.current_tool_id].clear()
|
||||||
self.current_tool_name_sent = False
|
self.current_tool_name_sent = False
|
||||||
self.streamed_args_for_tool[self.current_tool_id] = ""
|
self.streamed_args_for_tool[self.current_tool_id] = ""
|
||||||
|
self.current_tool_id += 1
|
||||||
|
|
||||||
|
# If the tool is still being parsed, send incremental changes
|
||||||
elif prev_arguments:
|
elif prev_arguments:
|
||||||
prev_args_json = json.dumps(prev_arguments)
|
prev_args_json = json.dumps(prev_arguments)
|
||||||
if cur_args_json != prev_args_json:
|
if cur_args_json != prev_args_json:
|
||||||
prefix = _find_common_prefix(prev_args_json, cur_args_json)
|
prefix = _find_common_prefix(prev_args_json, cur_args_json)
|
||||||
argument_diff = prefix[sent:]
|
argument_diff = prefix[sent:]
|
||||||
|
|
||||||
|
# Send the argument diff if there's something new
|
||||||
if argument_diff is not None:
|
if argument_diff is not None:
|
||||||
|
# Use the correct tool_index: completing_tool_id for completed tools, current_tool_id for ongoing
|
||||||
|
tool_index_to_use = (
|
||||||
|
completing_tool_id
|
||||||
|
if is_current_complete
|
||||||
|
else self.current_tool_id
|
||||||
|
)
|
||||||
res = StreamingParseResult(
|
res = StreamingParseResult(
|
||||||
calls=[
|
calls=[
|
||||||
ToolCallItem(
|
ToolCallItem(
|
||||||
tool_index=self.current_tool_id,
|
tool_index=tool_index_to_use,
|
||||||
parameters=argument_diff,
|
parameters=argument_diff,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
if not is_complete[self.current_tool_id]:
|
if not is_current_complete:
|
||||||
self.streamed_args_for_tool[
|
self.streamed_args_for_tool[
|
||||||
self.current_tool_id
|
self.current_tool_id
|
||||||
] += argument_diff
|
] += argument_diff
|
||||||
|
|
||||||
self.prev_tool_call_arr = tool_call_arr
|
# Update prev_tool_call_arr with current state
|
||||||
|
if self.current_tool_id >= 0:
|
||||||
|
# Ensure prev_tool_call_arr is large enough
|
||||||
|
while len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||||
|
self.prev_tool_call_arr.append({})
|
||||||
|
self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -24,6 +24,11 @@ class Llama32Detector(BaseFormatDetector):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.bot_token = "<|python_tag|>"
|
self.bot_token = "<|python_tag|>"
|
||||||
|
# NOTE: technically, Llama 3.2 does not support parallel tool calls well
|
||||||
|
# They need specific prompt engineering to support parallel tool calls
|
||||||
|
# Here we use ';' as the separator, which might have compatibility issues
|
||||||
|
# if users specify a different separator in their prompt
|
||||||
|
self.tool_call_separator = ";"
|
||||||
|
|
||||||
def has_tool_call(self, text: str) -> bool:
|
def has_tool_call(self, text: str) -> bool:
|
||||||
"""Check if the text contains a Llama 3.2 format tool call."""
|
"""Check if the text contains a Llama 3.2 format tool call."""
|
||||||
@@ -42,7 +47,11 @@ class Llama32Detector(BaseFormatDetector):
|
|||||||
normal_text, action_text = "", text
|
normal_text, action_text = "", text
|
||||||
|
|
||||||
# Split by semicolon and process each part
|
# Split by semicolon and process each part
|
||||||
json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
|
json_parts = [
|
||||||
|
part.strip()
|
||||||
|
for part in action_text.split(self.tool_call_separator)
|
||||||
|
if part.strip()
|
||||||
|
]
|
||||||
all_actions = []
|
all_actions = []
|
||||||
for part in json_parts:
|
for part in json_parts:
|
||||||
try:
|
try:
|
||||||
@@ -70,5 +79,5 @@ class Llama32Detector(BaseFormatDetector):
|
|||||||
return EBNFComposer.build_ebnf(
|
return EBNFComposer.build_ebnf(
|
||||||
tools,
|
tools,
|
||||||
function_format="json",
|
function_format="json",
|
||||||
tool_call_separator=",",
|
tool_call_separator=self.tool_call_separator,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ class MistralDetector(BaseFormatDetector):
|
|||||||
self.bot_token = "[TOOL_CALLS] ["
|
self.bot_token = "[TOOL_CALLS] ["
|
||||||
self.eot_token = "]"
|
self.eot_token = "]"
|
||||||
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
||||||
|
self.tool_call_separator = ", "
|
||||||
|
|
||||||
def has_tool_call(self, text: str) -> bool:
|
def has_tool_call(self, text: str) -> bool:
|
||||||
"""Check if the text contains a Mistral format tool call."""
|
"""Check if the text contains a Mistral format tool call."""
|
||||||
@@ -126,5 +127,5 @@ class MistralDetector(BaseFormatDetector):
|
|||||||
sequence_start_token=self.bot_token,
|
sequence_start_token=self.bot_token,
|
||||||
sequence_end_token=self.eot_token,
|
sequence_end_token=self.eot_token,
|
||||||
function_format="json",
|
function_format="json",
|
||||||
tool_call_separator=", ",
|
tool_call_separator=self.tool_call_separator,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ class Qwen25Detector(BaseFormatDetector):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.bot_token = "<tool_call>\n"
|
self.bot_token = "<tool_call>\n"
|
||||||
self.eot_token = "\n</tool_call>"
|
self.eot_token = "\n</tool_call>"
|
||||||
|
self.tool_call_separator = "\n"
|
||||||
self._normal_text_buffer = "" # Buffer for handling partial end tokens
|
self._normal_text_buffer = "" # Buffer for handling partial end tokens
|
||||||
|
|
||||||
def has_tool_call(self, text: str) -> bool:
|
def has_tool_call(self, text: str) -> bool:
|
||||||
@@ -104,7 +105,6 @@ class Qwen25Detector(BaseFormatDetector):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def structure_info(self) -> _GetInfoFunc:
|
def structure_info(self) -> _GetInfoFunc:
|
||||||
# TODO: Update the begin and end tokens with '\n' if necessary
|
|
||||||
return lambda name: StructureInfo(
|
return lambda name: StructureInfo(
|
||||||
begin='<tool_call>\n{"name":"' + name + '", "arguments":',
|
begin='<tool_call>\n{"name":"' + name + '", "arguments":',
|
||||||
end="}\n</tool_call>",
|
end="}\n</tool_call>",
|
||||||
|
|||||||
@@ -18,6 +18,23 @@ def _find_common_prefix(s1: str, s2: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
|
def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
|
||||||
|
"""
|
||||||
|
Parse incomplete or partial JSON strings commonly encountered during streaming.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_str (str): The potentially incomplete JSON string to parse.
|
||||||
|
flags (Allow): Bitwise flags controlling what types of partial data are allowed.
|
||||||
|
Common flags include:
|
||||||
|
- Allow.STR: Allow partial strings (e.g., '"hello wo' -> 'hello wo')
|
||||||
|
- Allow.OBJ: Allow partial objects (e.g., '{"key":' -> {'key': None})
|
||||||
|
- Allow.ARR: Allow partial arrays (e.g., '[1, 2,' -> [1, 2])
|
||||||
|
- Allow.ALL: Allow all types of partial data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[Any, int]: A tuple containing:
|
||||||
|
- parsed_object: The Python object parsed from the JSON
|
||||||
|
- consumed_length: Number of characters consumed from input_str
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
return (partial_json_parser.loads(input_str, flags), len(input_str))
|
return (partial_json_parser.loads(input_str, flags), len(input_str))
|
||||||
except JSONDecodeError as e:
|
except JSONDecodeError as e:
|
||||||
|
|||||||
@@ -1327,7 +1327,6 @@ def v1_chat_generate_response(
|
|||||||
tool_calls = [
|
tool_calls = [
|
||||||
ToolCall(
|
ToolCall(
|
||||||
id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
|
id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
|
||||||
index=call_info.tool_index,
|
|
||||||
function=FunctionResponse(
|
function=FunctionResponse(
|
||||||
name=call_info.name, arguments=call_info.parameters
|
name=call_info.name, arguments=call_info.parameters
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import unittest
|
|||||||
|
|
||||||
from xgrammar import GrammarCompiler, TokenizerInfo
|
from xgrammar import GrammarCompiler, TokenizerInfo
|
||||||
|
|
||||||
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||||
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
|
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
|
||||||
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
||||||
from sglang.srt.function_call.mistral_detector import MistralDetector
|
from sglang.srt.function_call.mistral_detector import MistralDetector
|
||||||
@@ -516,5 +517,237 @@ class TestEBNFGeneration(unittest.TestCase):
|
|||||||
self.fail(f"Failed to compile EBNF: {e}")
|
self.fail(f"Failed to compile EBNF: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseFormatDetector(unittest.TestCase):
|
||||||
|
"""Test buffer management and sequential tool index assignment in BaseFormatDetector."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up test detector and tools."""
|
||||||
|
|
||||||
|
# Create a concrete implementation of BaseFormatDetector for testing
|
||||||
|
class TestFormatDetector(BaseFormatDetector):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.bot_token = "<tool_call>"
|
||||||
|
self.eot_token = "</tool_call>"
|
||||||
|
|
||||||
|
def detect_and_parse(self, text, tools):
|
||||||
|
# Not used in streaming tests
|
||||||
|
pass
|
||||||
|
|
||||||
|
def has_tool_call(self, text):
|
||||||
|
return "<tool_call>" in text
|
||||||
|
|
||||||
|
def structure_info(self):
|
||||||
|
# Not used in streaming tests
|
||||||
|
pass
|
||||||
|
|
||||||
|
def build_ebnf(self, tools):
|
||||||
|
# Not used in streaming tests
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.detector = TestFormatDetector()
|
||||||
|
self.tools = [
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="get_weather",
|
||||||
|
description="Get weather information",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "City name",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["city"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="get_tourist_attractions",
|
||||||
|
description="Get tourist attractions",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "City name",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["city"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_sequential_tool_index_assignment(self):
|
||||||
|
"""Test that multiple tool calls get sequential tool_index values (0, 1, 2, ...)."""
|
||||||
|
# Simulate streaming chunks for two consecutive tool calls
|
||||||
|
chunks = [
|
||||||
|
"<tool_call>",
|
||||||
|
'{"name": "get_weather", ',
|
||||||
|
'"arguments": {"city": "Paris"}}',
|
||||||
|
", ",
|
||||||
|
'{"name": "get_tourist_attractions", ',
|
||||||
|
'"arguments": {"city": "London"}}',
|
||||||
|
"</tool_call>",
|
||||||
|
]
|
||||||
|
|
||||||
|
tool_indices_seen = []
|
||||||
|
|
||||||
|
for chunk in chunks:
|
||||||
|
result = self.detector.parse_streaming_increment(chunk, self.tools)
|
||||||
|
|
||||||
|
if result.calls:
|
||||||
|
for call in result.calls:
|
||||||
|
if call.tool_index is not None:
|
||||||
|
tool_indices_seen.append(call.tool_index)
|
||||||
|
|
||||||
|
# Verify we got sequential tool indices
|
||||||
|
unique_indices = sorted(set(tool_indices_seen))
|
||||||
|
self.assertEqual(
|
||||||
|
unique_indices,
|
||||||
|
[0, 1],
|
||||||
|
f"Expected sequential tool indices [0, 1], got {unique_indices}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_buffer_content_preservation(self):
|
||||||
|
"""Test that buffer correctly preserves unprocessed content when tool completes."""
|
||||||
|
# Test simpler scenario: tool completion followed by new tool start
|
||||||
|
chunks = [
|
||||||
|
"<tool_call>",
|
||||||
|
'{"name": "get_weather", ',
|
||||||
|
'"arguments": {"city": "Paris"}}',
|
||||||
|
", ",
|
||||||
|
'{"name": "get_tourist_attractions", ',
|
||||||
|
'"arguments": {"city": "London"}} </tool_call>',
|
||||||
|
]
|
||||||
|
|
||||||
|
tool_calls_seen = []
|
||||||
|
|
||||||
|
for chunk in chunks:
|
||||||
|
result = self.detector.parse_streaming_increment(chunk, self.tools)
|
||||||
|
if result.calls:
|
||||||
|
for call in result.calls:
|
||||||
|
if (
|
||||||
|
call.name
|
||||||
|
): # Only count calls with names (not just parameter updates)
|
||||||
|
tool_calls_seen.append(call.name)
|
||||||
|
|
||||||
|
# Should see both tool names
|
||||||
|
self.assertIn("get_weather", tool_calls_seen, "Should process first tool")
|
||||||
|
self.assertIn(
|
||||||
|
"get_tourist_attractions", tool_calls_seen, "Should process second tool"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_current_tool_id_increment_on_completion(self):
|
||||||
|
"""Test that current_tool_id increments when a tool completes."""
|
||||||
|
# Initial state
|
||||||
|
self.assertEqual(
|
||||||
|
self.detector.current_tool_id, -1, "Should start with current_tool_id=-1"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stream the start token and the first tool's name (arguments not sent yet)
|
||||||
|
chunks = [
|
||||||
|
"<tool_call>",
|
||||||
|
'{"name": "get_weather", ',
|
||||||
|
]
|
||||||
|
|
||||||
|
for chunk in chunks:
|
||||||
|
result = self.detector.parse_streaming_increment(chunk, self.tools)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
self.detector.current_tool_id, 0, "current_tool_id should be 0"
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
result.calls[0].name, "get_weather", "The first tool should be get_weather"
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
result.calls[0].tool_index, 0, "The first tool index should be 0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Complete the first tool and begin the second tool's name - current_tool_id should now be 1
|
||||||
|
result = self.detector.parse_streaming_increment(
|
||||||
|
'"arguments": {"city": "Paris"}}, {"name": "get_', self.tools
|
||||||
|
)
|
||||||
|
self.assertEqual(result.calls[0].parameters, '{"city": "Paris"}')
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
self.detector.current_tool_id,
|
||||||
|
1,
|
||||||
|
"current_tool_id should be 1 after first tool completes and second tool starts",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self.detector.parse_streaming_increment(
|
||||||
|
'tourist_attractions", ', self.tools
|
||||||
|
)
|
||||||
|
|
||||||
|
# Second tool should have tool_index=1
|
||||||
|
tourist_calls = [
|
||||||
|
call for call in result.calls if call.name == "get_tourist_attractions"
|
||||||
|
]
|
||||||
|
self.assertEqual(
|
||||||
|
tourist_calls[0].tool_index, 1, "Second tool should have tool_index=1"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_tool_name_streaming_with_correct_index(self):
|
||||||
|
"""Test that tool names are streamed with correct tool_index values."""
|
||||||
|
# Process first tool
|
||||||
|
self.detector.parse_streaming_increment("<tool_call>", self.tools)
|
||||||
|
result1 = self.detector.parse_streaming_increment(
|
||||||
|
'{"name": "get_weather", ', self.tools
|
||||||
|
)
|
||||||
|
|
||||||
|
# First tool name should have tool_index=0
|
||||||
|
weather_calls = [call for call in result1.calls if call.name == "get_weather"]
|
||||||
|
self.assertEqual(len(weather_calls), 1, "Should have one weather call")
|
||||||
|
self.assertEqual(
|
||||||
|
weather_calls[0].tool_index, 0, "First tool should have tool_index=0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Complete first tool
|
||||||
|
self.detector.parse_streaming_increment(
|
||||||
|
'"arguments": {"city": "Paris"}}', self.tools
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start second tool
|
||||||
|
self.detector.parse_streaming_increment(", ", self.tools)
|
||||||
|
result2 = self.detector.parse_streaming_increment(
|
||||||
|
'{"name": "get_tourist_attractions", ', self.tools
|
||||||
|
)
|
||||||
|
|
||||||
|
# Second tool name should have tool_index=1
|
||||||
|
tourist_calls = [
|
||||||
|
call for call in result2.calls if call.name == "get_tourist_attractions"
|
||||||
|
]
|
||||||
|
self.assertEqual(
|
||||||
|
len(tourist_calls), 1, "Should have one tourist attractions call"
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
tourist_calls[0].tool_index, 1, "Second tool should have tool_index=1"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_buffer_reset_on_invalid_tool(self):
|
||||||
|
"""Test that buffer and state are reset when an invalid tool name is encountered."""
|
||||||
|
# Start fresh with an invalid tool name from the beginning
|
||||||
|
result = self.detector.parse_streaming_increment(
|
||||||
|
'<tool_call>{"name": "invalid_tool", ', self.tools
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return empty result and reset state
|
||||||
|
self.assertEqual(result.calls, [], "Should return no calls for invalid tool")
|
||||||
|
self.assertEqual(
|
||||||
|
self.detector.current_tool_id,
|
||||||
|
-1,
|
||||||
|
"current_tool_id should remain -1 for invalid tool",
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
self.detector._buffer, "", "Buffer should be cleared for invalid tool"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user