refactor(tool call): Fix BaseFormatDetector tool_index issue and refactor parse_streaming_increment (#6715)
This commit is contained in:
@@ -3,6 +3,7 @@ import unittest
|
||||
|
||||
from xgrammar import GrammarCompiler, TokenizerInfo
|
||||
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
|
||||
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
||||
from sglang.srt.function_call.mistral_detector import MistralDetector
|
||||
@@ -516,5 +517,237 @@ class TestEBNFGeneration(unittest.TestCase):
|
||||
self.fail(f"Failed to compile EBNF: {e}")
|
||||
|
||||
|
||||
class TestBaseFormatDetector(unittest.TestCase):
|
||||
"""Test buffer management and sequential tool index assignment in BaseFormatDetector."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test detector and tools."""
|
||||
|
||||
# Create a concrete implementation of BaseFormatDetector for testing
|
||||
class TestFormatDetector(BaseFormatDetector):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bot_token = "<tool_call>"
|
||||
self.eot_token = "</tool_call>"
|
||||
|
||||
def detect_and_parse(self, text, tools):
|
||||
# Not used in streaming tests
|
||||
pass
|
||||
|
||||
def has_tool_call(self, text):
|
||||
return "<tool_call>" in text
|
||||
|
||||
def structure_info(self):
|
||||
# Not used in streaming tests
|
||||
pass
|
||||
|
||||
def build_ebnf(self, tools):
|
||||
# Not used in streaming tests
|
||||
pass
|
||||
|
||||
self.detector = TestFormatDetector()
|
||||
self.tools = [
|
||||
Tool(
|
||||
type="function",
|
||||
function=Function(
|
||||
name="get_weather",
|
||||
description="Get weather information",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City name",
|
||||
}
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
),
|
||||
),
|
||||
Tool(
|
||||
type="function",
|
||||
function=Function(
|
||||
name="get_tourist_attractions",
|
||||
description="Get tourist attractions",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City name",
|
||||
}
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
def test_sequential_tool_index_assignment(self):
|
||||
"""Test that multiple tool calls get sequential tool_index values (0, 1, 2, ...)."""
|
||||
# Simulate streaming chunks for two consecutive tool calls
|
||||
chunks = [
|
||||
"<tool_call>",
|
||||
'{"name": "get_weather", ',
|
||||
'"arguments": {"city": "Paris"}}',
|
||||
", ",
|
||||
'{"name": "get_tourist_attractions", ',
|
||||
'"arguments": {"city": "London"}}',
|
||||
"</tool_call>",
|
||||
]
|
||||
|
||||
tool_indices_seen = []
|
||||
|
||||
for chunk in chunks:
|
||||
result = self.detector.parse_streaming_increment(chunk, self.tools)
|
||||
|
||||
if result.calls:
|
||||
for call in result.calls:
|
||||
if call.tool_index is not None:
|
||||
tool_indices_seen.append(call.tool_index)
|
||||
|
||||
# Verify we got sequential tool indices
|
||||
unique_indices = sorted(set(tool_indices_seen))
|
||||
self.assertEqual(
|
||||
unique_indices,
|
||||
[0, 1],
|
||||
f"Expected sequential tool indices [0, 1], got {unique_indices}",
|
||||
)
|
||||
|
||||
def test_buffer_content_preservation(self):
|
||||
"""Test that buffer correctly preserves unprocessed content when tool completes."""
|
||||
# Test simpler scenario: tool completion followed by new tool start
|
||||
chunks = [
|
||||
"<tool_call>",
|
||||
'{"name": "get_weather", ',
|
||||
'"arguments": {"city": "Paris"}}',
|
||||
", ",
|
||||
'{"name": "get_tourist_attractions", ',
|
||||
'"arguments": {"city": "London"}} </tool_call>',
|
||||
]
|
||||
|
||||
tool_calls_seen = []
|
||||
|
||||
for chunk in chunks:
|
||||
result = self.detector.parse_streaming_increment(chunk, self.tools)
|
||||
if result.calls:
|
||||
for call in result.calls:
|
||||
if (
|
||||
call.name
|
||||
): # Only count calls with names (not just parameter updates)
|
||||
tool_calls_seen.append(call.name)
|
||||
|
||||
# Should see both tool names
|
||||
self.assertIn("get_weather", tool_calls_seen, "Should process first tool")
|
||||
self.assertIn(
|
||||
"get_tourist_attractions", tool_calls_seen, "Should process second tool"
|
||||
)
|
||||
|
||||
def test_current_tool_id_increment_on_completion(self):
|
||||
"""Test that current_tool_id increments when a tool completes."""
|
||||
# Initial state
|
||||
self.assertEqual(
|
||||
self.detector.current_tool_id, -1, "Should start with current_tool_id=-1"
|
||||
)
|
||||
|
||||
# Process first tool completely
|
||||
chunks = [
|
||||
"<tool_call>",
|
||||
'{"name": "get_weather", ',
|
||||
]
|
||||
|
||||
for chunk in chunks:
|
||||
result = self.detector.parse_streaming_increment(chunk, self.tools)
|
||||
|
||||
self.assertEqual(
|
||||
self.detector.current_tool_id, 0, "current_tool_id should be 0"
|
||||
)
|
||||
self.assertEqual(
|
||||
result.calls[0].name, "get_weather", "The first tool should be get_weather"
|
||||
)
|
||||
self.assertEqual(
|
||||
result.calls[0].tool_index, 0, "The first tool index should be 0"
|
||||
)
|
||||
|
||||
# Complete second tool name - this should show that current_tool_id is now 1
|
||||
result = self.detector.parse_streaming_increment(
|
||||
'"arguments": {"city": "Paris"}}, {"name": "get_', self.tools
|
||||
)
|
||||
self.assertEqual(result.calls[0].parameters, '{"city": "Paris"}')
|
||||
|
||||
self.assertEqual(
|
||||
self.detector.current_tool_id,
|
||||
1,
|
||||
"current_tool_id should be 1 after first tool completes and second tool starts",
|
||||
)
|
||||
|
||||
result = self.detector.parse_streaming_increment(
|
||||
'tourist_attractions", ', self.tools
|
||||
)
|
||||
|
||||
# Second tool should have tool_index=1
|
||||
tourist_calls = [
|
||||
call for call in result.calls if call.name == "get_tourist_attractions"
|
||||
]
|
||||
self.assertEqual(
|
||||
tourist_calls[0].tool_index, 1, "Second tool should have tool_index=1"
|
||||
)
|
||||
|
||||
def test_tool_name_streaming_with_correct_index(self):
|
||||
"""Test that tool names are streamed with correct tool_index values."""
|
||||
# Process first tool
|
||||
self.detector.parse_streaming_increment("<tool_call>", self.tools)
|
||||
result1 = self.detector.parse_streaming_increment(
|
||||
'{"name": "get_weather", ', self.tools
|
||||
)
|
||||
|
||||
# First tool name should have tool_index=0
|
||||
weather_calls = [call for call in result1.calls if call.name == "get_weather"]
|
||||
self.assertEqual(len(weather_calls), 1, "Should have one weather call")
|
||||
self.assertEqual(
|
||||
weather_calls[0].tool_index, 0, "First tool should have tool_index=0"
|
||||
)
|
||||
|
||||
# Complete first tool
|
||||
self.detector.parse_streaming_increment(
|
||||
'"arguments": {"city": "Paris"}}', self.tools
|
||||
)
|
||||
|
||||
# Start second tool
|
||||
self.detector.parse_streaming_increment(", ", self.tools)
|
||||
result2 = self.detector.parse_streaming_increment(
|
||||
'{"name": "get_tourist_attractions", ', self.tools
|
||||
)
|
||||
|
||||
# Second tool name should have tool_index=1
|
||||
tourist_calls = [
|
||||
call for call in result2.calls if call.name == "get_tourist_attractions"
|
||||
]
|
||||
self.assertEqual(
|
||||
len(tourist_calls), 1, "Should have one tourist attractions call"
|
||||
)
|
||||
self.assertEqual(
|
||||
tourist_calls[0].tool_index, 1, "Second tool should have tool_index=1"
|
||||
)
|
||||
|
||||
def test_buffer_reset_on_invalid_tool(self):
|
||||
"""Test that buffer and state are reset when an invalid tool name is encountered."""
|
||||
# Start fresh with an invalid tool name from the beginning
|
||||
result = self.detector.parse_streaming_increment(
|
||||
'<tool_call>{"name": "invalid_tool", ', self.tools
|
||||
)
|
||||
|
||||
# Should return empty result and reset state
|
||||
self.assertEqual(result.calls, [], "Should return no calls for invalid tool")
|
||||
self.assertEqual(
|
||||
self.detector.current_tool_id,
|
||||
-1,
|
||||
"current_tool_id should remain -1 for invalid tool",
|
||||
)
|
||||
self.assertEqual(
|
||||
self.detector._buffer, "", "Buffer should be cleared for invalid tool"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user