Use jsonschema to constrain required or specific tool choice (#10550)
This commit is contained in:
@@ -5,8 +5,10 @@ from xgrammar import GrammarCompiler, TokenizerInfo
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import Function, Tool
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.core_types import StreamingParseResult
|
||||
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
|
||||
from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector
|
||||
from sglang.srt.function_call.json_array_parser import JsonArrayParser
|
||||
from sglang.srt.function_call.kimik2_detector import KimiK2Detector
|
||||
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
||||
from sglang.srt.function_call.mistral_detector import MistralDetector
|
||||
@@ -2190,5 +2192,322 @@ class TestGlm4MoeDetector(unittest.TestCase):
|
||||
self.assertEqual(self.detector._buffer, "")
|
||||
|
||||
|
||||
class TestJsonArrayParser(unittest.TestCase):
    """Tests for JsonArrayParser's incremental (streaming) tool-call parsing.

    Each test feeds a JSON array of tool calls to
    ``parse_streaming_increment`` in several chunks, mimicking how tokens
    arrive from a model, and checks that the parser either produces tool
    calls or at least degrades gracefully (returns a StreamingParseResult
    without raising). Note the parser is stateful: ``self.detector`` buffers
    partial input across calls within a single test.
    """

    def setUp(self):
        """Build two sample tool definitions and a fresh parser per test."""
        # Create sample tools for testing
        self.tools = [
            Tool(
                type="function",
                function=Function(
                    name="get_weather",
                    description="Get weather information",
                    parameters={
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "Location to get weather for",
                            },
                            "unit": {
                                "type": "string",
                                "description": "Temperature unit",
                                "enum": ["celsius", "fahrenheit"],
                            },
                        },
                        "required": ["location"],
                    },
                ),
            ),
            Tool(
                type="function",
                function=Function(
                    name="search",
                    description="Search for information",
                    parameters={
                        "properties": {
                            "query": {
                                "type": "string",
                                "description": "Search query",
                            },
                        },
                        "required": ["query"],
                    },
                ),
            ),
        ]
        # Fresh parser per test — it carries streaming state across calls.
        self.detector = JsonArrayParser()

    def test_json_detector_ebnf(self):
        """Test that the JsonArrayParser returns NotImplementedError for EBNF."""
        with self.assertRaises(NotImplementedError) as context:
            self.detector.build_ebnf(self.tools)
        # Message check runs after the `with` block; context.exception is
        # still available there.
        self.assertIn(
            "EBNF generation is not supported for JSON schema constraints",
            str(context.exception),
        )

    def test_parse_streaming_increment_malformed_json(self):
        """Test parsing with malformed JSON"""
        # Test with malformed JSON (truncated object — no closing braces)
        text = '[{"name": "get_weather", "parameters": {"location": "Tokyo"'
        result = self.detector.parse_streaming_increment(text, self.tools)

        # Should not crash and return a valid result
        self.assertIsInstance(result, StreamingParseResult)

        # Extra closing braces — structurally invalid JSON
        text = "[{}}}]"
        result = self.detector.parse_streaming_increment(text, self.tools)

        self.assertIsInstance(result, StreamingParseResult)

    def test_parse_streaming_increment_empty_input(self):
        """Test parsing with empty input"""
        # An empty chunk should yield no calls and no normal text.
        result = self.detector.parse_streaming_increment("", self.tools)
        self.assertEqual(len(result.calls), 0)
        self.assertEqual(result.normal_text, "")

    def test_parse_streaming_increment_whitespace_handling(self):
        """Test parsing with various whitespace scenarios"""
        # Test with leading/trailing whitespace split across chunks
        chunk1 = ' [{"name": "get_weather", "parameters": '
        result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
        self.assertIsInstance(result1, StreamingParseResult)
        chunk2 = '{"location": "Tokyo"}}] '
        result2 = self.detector.parse_streaming_increment(chunk2, self.tools)

        # The base class should handle this
        self.assertIsInstance(result2, StreamingParseResult)

    def test_parse_streaming_increment_nested_objects(self):
        """Test parsing with nested JSON objects"""
        # Parameters object is split mid-way; the nested object arrives in
        # the second chunk.
        chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tokyo", '
        result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
        self.assertIsInstance(result1, StreamingParseResult)
        chunk2 = '"nested": {"key": "value"}}}]'
        result2 = self.detector.parse_streaming_increment(chunk2, self.tools)

        # The base class should handle this
        self.assertIsInstance(result2, StreamingParseResult)

    def test_json_parsing_with_commas(self):
        """Test that JSON parsing works correctly with comma separators"""
        # Stream two complete objects, at least 2 chunks per tool call
        chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tok'
        result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
        self.assertIsInstance(result1, StreamingParseResult)
        chunk2 = 'yo"}},'
        result2 = self.detector.parse_streaming_increment(chunk2, self.tools)
        self.assertIsInstance(result2, StreamingParseResult)

        chunk3 = '{"name": "get_weather", "parameters": {"location": "Par'
        result3 = self.detector.parse_streaming_increment(chunk3, self.tools)
        self.assertIsInstance(result3, StreamingParseResult)
        chunk4 = 'is"}}]'
        result4 = self.detector.parse_streaming_increment(chunk4, self.tools)
        self.assertIsInstance(result4, StreamingParseResult)
        self.assertGreater(
            len(result4.calls), 0, "Should parse tool calls from text with separators"
        )

    def test_braces_in_strings(self):
        """Test that JSON with } characters inside strings works correctly"""
        # Test case: JSON array with } inside string values - streamed across chunks
        chunk1 = '[{"name": "get_weather", "parameters": {"location": "has } inside"'
        result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
        self.assertIsInstance(result1, StreamingParseResult)
        chunk2 = "}}"
        result2 = self.detector.parse_streaming_increment(chunk2, self.tools)
        self.assertIsInstance(result2, StreamingParseResult)
        self.assertGreater(
            len(result2.calls), 0, "Should parse tool call with } in string"
        )

        # Test with separator (streaming in progress)
        chunk3 = '[{"name": "get_weather", "parameters": {"location": "has } inside"}'
        result3 = self.detector.parse_streaming_increment(chunk3, self.tools)
        self.assertIsInstance(result3, StreamingParseResult)
        chunk4 = "},"
        result4 = self.detector.parse_streaming_increment(chunk4, self.tools)
        self.assertIsInstance(result4, StreamingParseResult)
        chunk5 = '{"name": "get_weather"'
        result5 = self.detector.parse_streaming_increment(chunk5, self.tools)
        self.assertIsInstance(result5, StreamingParseResult)
        self.assertGreater(
            len(result5.calls),
            0,
            "Should parse tool calls with separator and } in string",
        )

    def test_separator_in_same_chunk(self):
        """Test that separator already present in chunk works correctly"""
        # Test case: separator already in the chunk (streaming in progress) with 2+ chunks per tool call
        chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tokyo"'
        result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
        self.assertIsInstance(result1, StreamingParseResult)
        # chunk2 both closes the first call and opens the next one.
        chunk2 = '}},{"name": "get_weather"'
        result2 = self.detector.parse_streaming_increment(chunk2, self.tools)
        self.assertIsInstance(result2, StreamingParseResult)
        self.assertGreater(
            len(result2.calls),
            0,
            "Should parse tool calls with separator in same chunk",
        )

    def test_separator_in_separate_chunk(self):
        """Test that separator in separate chunk works correctly"""
        # Test case: separator in separate chunk - this tests streaming behavior
        chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tokyo"}}'
        chunk2 = ","
        chunk3 = '{"name": "get_weather", "parameters": {"location": "Paris"}}'

        # Process first chunk
        result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
        self.assertIsInstance(result1, StreamingParseResult)

        # Process separator chunk
        result2 = self.detector.parse_streaming_increment(chunk2, self.tools)
        self.assertIsInstance(result2, StreamingParseResult)

        # Process second chunk (streaming in progress)
        result3 = self.detector.parse_streaming_increment(chunk3, self.tools)
        self.assertIsInstance(result3, StreamingParseResult)

    def test_incomplete_json_across_chunks(self):
        """Test that incomplete JSON across chunks works correctly"""
        # Test case: incomplete JSON across chunks - this tests streaming behavior
        chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tokyo"'
        chunk2 = '}},{"name": "get_weather"'

        # Process first chunk (incomplete)
        result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
        self.assertIsInstance(result1, StreamingParseResult)

        # Process second chunk (completes first object and starts second, streaming in progress)
        result2 = self.detector.parse_streaming_increment(chunk2, self.tools)
        self.assertIsInstance(result2, StreamingParseResult)

    def test_malformed_json_recovery(self):
        """Test that malformed JSON recovers gracefully"""
        # Test with malformed JSON - should handle gracefully
        malformed_text = (
            '[{"name": "get_weather", "parameters": {"location": "unclosed string'
        )

        result1 = self.detector.parse_streaming_increment(malformed_text, self.tools)
        self.assertIsInstance(result1, StreamingParseResult)

        # Test valid JSON after malformed - streamed across 2 chunks (streaming in progress)
        valid_chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tok'
        result2 = self.detector.parse_streaming_increment(valid_chunk1, self.tools)
        self.assertIsInstance(result2, StreamingParseResult)
        valid_chunk2 = 'yo"}}'
        result3 = self.detector.parse_streaming_increment(valid_chunk2, self.tools)
        self.assertIsInstance(result3, StreamingParseResult)

    def test_nested_objects_with_commas(self):
        """Test that nested objects with commas inside work correctly"""
        # Test with nested objects that have commas - should work with json.loads()
        chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tok'
        result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
        self.assertIsInstance(result1, StreamingParseResult)
        chunk2 = 'yo", "unit": "celsius"}}'
        result2 = self.detector.parse_streaming_increment(chunk2, self.tools)
        self.assertIsInstance(result2, StreamingParseResult)
        self.assertGreater(
            len(result2.calls), 0, "Should parse tool call with nested objects"
        )

    def test_empty_objects(self):
        """Test that empty objects work correctly"""
        # Test with empty objects - should work with json.loads()
        chunk1 = '[{"name": "get_weather", "parameters": '
        result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
        self.assertIsInstance(result1, StreamingParseResult)
        chunk2 = "{}}"
        result2 = self.detector.parse_streaming_increment(chunk2, self.tools)
        self.assertIsInstance(result2, StreamingParseResult)

    def test_whitespace_handling(self):
        """Test that various whitespace scenarios work correctly"""
        # Test with various whitespace patterns - should work with json.loads()
        chunk1 = ' \n\n [{"name": "get_weather", "parameters": '
        result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
        self.assertIsInstance(result1, StreamingParseResult)
        chunk2 = '{"location": "Tokyo"}}'
        result2 = self.detector.parse_streaming_increment(chunk2, self.tools)
        self.assertIsInstance(result2, StreamingParseResult)

    def test_multiple_commas_in_chunk(self):
        """Test that multiple commas in a single chunk work correctly"""
        # Stream multiple tool calls ensuring at least 2 chunks per complete tool call
        chunk1 = '[{"name": "get_weather", "parameters": {"location": "To'
        result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
        self.assertIsInstance(result1, StreamingParseResult)
        chunk2 = 'kyo"}},'
        result2 = self.detector.parse_streaming_increment(chunk2, self.tools)
        self.assertIsInstance(result2, StreamingParseResult)

        chunk3 = '{"name": "get_weather", "parameters": {"location": "Pa'
        result3 = self.detector.parse_streaming_increment(chunk3, self.tools)
        self.assertIsInstance(result3, StreamingParseResult)
        chunk4 = 'ris"}},'
        result4 = self.detector.parse_streaming_increment(chunk4, self.tools)
        self.assertIsInstance(result4, StreamingParseResult)

        chunk5 = '{"name": "get_weather"'
        result5 = self.detector.parse_streaming_increment(chunk5, self.tools)
        self.assertIsInstance(result5, StreamingParseResult)
        self.assertGreater(
            len(result5.calls), 0, "Should parse tool calls with multiple commas"
        )

    def test_complete_tool_call_with_trailing_comma(self):
        """Test that complete tool call with trailing comma parses correctly"""
        # Test case: complete tool call followed by comma at end of chunk (split across 2 chunks)
        chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tokyo"}'
        result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
        self.assertIsInstance(result1, StreamingParseResult)
        chunk2 = "}, "
        result2 = self.detector.parse_streaming_increment(chunk2, self.tools)
        self.assertIsInstance(result2, StreamingParseResult)
        self.assertGreater(len(result2.calls), 0, "Should parse complete tool call")

        # Test that next chunk with opening brace gets the separator prepended
        next_chunk = '{"name": "get_weather", "parameters": {"location": "Paris"}}'
        result_next = self.detector.parse_streaming_increment(next_chunk, self.tools)
        self.assertIsInstance(result_next, StreamingParseResult)
        self.assertGreater(
            len(result_next.calls), 0, "Should parse subsequent tool call"
        )

    def test_three_tool_calls_separate_chunks_with_commas(self):
        """Test parsing 3 tool calls in separate chunks with commas at the end"""
        # First tool call: 2 chunks
        chunk1_1 = '[{"name": "get_weather", "parameters": '
        result1_1 = self.detector.parse_streaming_increment(chunk1_1, self.tools)
        chunk1_2 = '{"location": "Tokyo"}},'
        result1_2 = self.detector.parse_streaming_increment(chunk1_2, self.tools)
        self.assertIsInstance(result1_2, StreamingParseResult)
        self.assertGreater(len(result1_2.calls), 0, "Should parse first tool call")

        # Second tool call: 2 chunks
        chunk2_1 = '{"name": "search", "parameters": '
        result2_1 = self.detector.parse_streaming_increment(chunk2_1, self.tools)
        chunk2_2 = '{"query": "restaurants"}},'
        result2_2 = self.detector.parse_streaming_increment(chunk2_2, self.tools)
        self.assertIsInstance(result2_2, StreamingParseResult)
        self.assertGreater(len(result2_2.calls), 0, "Should parse second tool call")

        # Third tool call: 2 chunks
        chunk3_1 = '{"name": "get_weather", "parameters": '
        result3_1 = self.detector.parse_streaming_increment(chunk3_1, self.tools)
        chunk3_2 = '{"location": "Paris"}}]'
        result3_2 = self.detector.parse_streaming_increment(chunk3_2, self.tools)
        self.assertIsInstance(result3_2, StreamingParseResult)
        self.assertGreater(len(result3_2.calls), 0, "Should parse third tool call")
        # Verify all tool calls were parsed correctly
        total_calls = len(result1_2.calls) + len(result2_2.calls) + len(result3_2.calls)
        self.assertEqual(total_calls, 3, "Should have parsed exactly 3 tool calls")
|
||||
|
||||
|
||||
# Allow running this test module directly (e.g. `python test_file.py`).
if __name__ == "__main__":
    unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user