409 lines
17 KiB
Python
409 lines
17 KiB
Python
|
|
import json
|
|||
|
|
import unittest
|
|||
|
|
|
|||
|
|
from xgrammar import GrammarCompiler, TokenizerInfo
|
|||
|
|
|
|||
|
|
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
|
|||
|
|
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
|||
|
|
from sglang.srt.function_call.mistral_detector import MistralDetector
|
|||
|
|
from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
|||
|
|
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
|||
|
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
|||
|
|
from sglang.srt.openai_api.protocol import Function, Tool
|
|||
|
|
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestPythonicDetector(unittest.TestCase):
    """Streaming-parse tests for ``PythonicDetector``.

    Every test feeds one or more text chunks to
    ``parse_streaming_increment`` and checks three things: the normal text
    returned, the tool calls returned, and what is left in the detector's
    private ``_buffer``.
    """

    def setUp(self):
        # Sample tools: a weather lookup (required "location", optional
        # enum "unit") and a search (required "query").
        weather_tool = Tool(
            type="function",
            function=Function(
                name="get_weather",
                description="Get weather information",
                parameters={
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "Location to get weather for",
                        },
                        "unit": {
                            "type": "string",
                            "description": "Temperature unit",
                            "enum": ["celsius", "fahrenheit"],
                        },
                    },
                    "required": ["location"],
                },
            ),
        )
        search_tool = Tool(
            type="function",
            function=Function(
                name="search",
                description="Search for information",
                parameters={
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "Search query",
                        },
                    },
                    "required": ["query"],
                },
            ),
        )
        self.tools = [weather_tool, search_tool]
        # A fresh detector per test so no buffered text leaks between tests.
        self.detector = PythonicDetector()

    def _feed(self, chunk):
        """Pass one streamed chunk to the detector along with the sample tools."""
        return self.detector.parse_streaming_increment(chunk, self.tools)

    @staticmethod
    def _decode_args(call):
        """Decode the JSON string stored in ``call.parameters``."""
        return json.loads(call.parameters)

    def test_parse_streaming_no_brackets(self):
        """Test parsing text with no brackets (no tool calls)."""
        plain = "This is just normal text without any tool calls."
        outcome = self._feed(plain)

        self.assertEqual(outcome.normal_text, plain)
        self.assertEqual(outcome.calls, [])
        self.assertEqual(self.detector._buffer, "")  # Buffer should be cleared

    def test_parse_streaming_complete_tool_call(self):
        """Test parsing a complete tool call."""
        chunk = "Here's a tool call: [get_weather(location='New York', unit='celsius')]"
        outcome = self._feed(chunk)

        self.assertEqual(outcome.normal_text, "Here's a tool call: ")
        self.assertEqual(len(outcome.calls), 1)
        first_call = outcome.calls[0]
        self.assertEqual(first_call.name, "get_weather")
        # Buffer should be cleared after processing
        self.assertEqual(self.detector._buffer, "")

        # Check the parameters
        args = self._decode_args(first_call)
        self.assertEqual(args["location"], "New York")
        self.assertEqual(args["unit"], "celsius")

    def test_parse_streaming_text_before_tool_call(self):
        """Test parsing text that appears before a tool call."""
        chunk = "This is some text before [get_weather(location='London')]"
        outcome = self._feed(chunk)

        self.assertEqual(outcome.normal_text, "This is some text before ")
        self.assertEqual(len(outcome.calls), 1)
        first_call = outcome.calls[0]
        self.assertEqual(first_call.name, "get_weather")

        # Check the parameters
        args = self._decode_args(first_call)
        self.assertEqual(args["location"], "London")

    def test_parse_streaming_partial_tool_call(self):
        """Test parsing a partial tool call that spans multiple chunks."""
        # First chunk with opening bracket but no closing bracket
        head = "Let me check the weather: [get_weather(location="
        head_outcome = self._feed(head)

        self.assertEqual(head_outcome.normal_text, "Let me check the weather: ")
        self.assertEqual(head_outcome.calls, [])
        # Partial tool call remains in buffer
        self.assertEqual(self.detector._buffer, "[get_weather(location=")

        # Second chunk completing the tool call
        tail = "'Paris')]"
        tail_outcome = self._feed(tail)

        self.assertEqual(tail_outcome.normal_text, "")
        self.assertEqual(len(tail_outcome.calls), 1)
        completed_call = tail_outcome.calls[0]
        self.assertEqual(completed_call.name, "get_weather")

        # Check the parameters
        args = self._decode_args(completed_call)
        self.assertEqual(args["location"], "Paris")
        # Buffer should be cleared after processing
        self.assertEqual(self.detector._buffer, "")

    def test_parse_streaming_bracket_without_text_before(self):
        """Test parsing a tool call that starts at the beginning of the text."""
        chunk = "[search(query='python programming')]"
        outcome = self._feed(chunk)

        self.assertEqual(outcome.normal_text, "")
        self.assertEqual(len(outcome.calls), 1)
        first_call = outcome.calls[0]
        self.assertEqual(first_call.name, "search")

        # Check the parameters
        args = self._decode_args(first_call)
        self.assertEqual(args["query"], "python programming")

    def test_parse_streaming_text_after_tool_call(self):
        """Test parsing text that appears after a tool call."""
        # First chunk with complete tool call and some text after
        chunk = "[get_weather(location='Tokyo')] Here's the forecast:"
        outcome = self._feed(chunk)

        self.assertEqual(outcome.normal_text, "")
        self.assertEqual(len(outcome.calls), 1)
        self.assertEqual(outcome.calls[0].name, "get_weather")
        # Text after tool call remains in buffer
        self.assertEqual(self.detector._buffer, " Here's the forecast:")

        # Process the remaining text in buffer: an empty chunk is fed and the
        # trailing text is expected back as normal text.
        flushed = self._feed("")
        self.assertEqual(flushed.normal_text, " Here's the forecast:")
        self.assertEqual(flushed.calls, [])
        self.assertEqual(self.detector._buffer, "")  # Buffer should be cleared

    def test_parse_streaming_multiple_tool_calls(self):
        """Test parsing multiple tool calls in sequence."""
        chunk = "[get_weather(location='Berlin')] and [search(query='restaurants')]"

        # First tool call
        first_outcome = self._feed(chunk)
        self.assertEqual(len(first_outcome.calls), 1)
        self.assertEqual(first_outcome.calls[0].name, "get_weather")
        self.assertEqual(self.detector._buffer, " and [search(query='restaurants')]")

        # Second tool call, produced from the buffered remainder by an empty chunk
        second_outcome = self._feed("")
        self.assertEqual(second_outcome.normal_text, " and ")
        self.assertEqual(len(second_outcome.calls), 1)
        self.assertEqual(second_outcome.calls[0].name, "search")
        self.assertEqual(self.detector._buffer, "")

    def test_parse_streaming_opening_bracket_only(self):
        """Test parsing text with only an opening bracket but no closing bracket."""
        chunk = "Let's try this: ["
        outcome = self._feed(chunk)

        self.assertEqual(outcome.normal_text, "Let's try this: ")
        self.assertEqual(outcome.calls, [])
        # Opening bracket remains in buffer
        self.assertEqual(self.detector._buffer, "[")

    def test_parse_streaming_nested_brackets(self):
        """Test parsing tool calls with nested brackets in arguments."""
        # Test with list argument containing nested brackets
        chunk = "[get_weather(location='New York', unit='celsius', data=[1, 2, 3])]"
        outcome = self._feed(chunk)

        self.assertEqual(outcome.normal_text, "")
        self.assertEqual(len(outcome.calls), 1)
        first_call = outcome.calls[0]
        self.assertEqual(first_call.name, "get_weather")
        self.assertEqual(self.detector._buffer, "")

        # Check the parameters
        args = self._decode_args(first_call)
        self.assertEqual(args["location"], "New York")
        self.assertEqual(args["unit"], "celsius")
        self.assertEqual(args["data"], [1, 2, 3])

    def test_parse_streaming_nested_brackets_dict(self):
        """Test parsing tool calls with nested dictionaries and lists."""
        # Test with nested dict and list arguments
        chunk = "[search(query='test', config={'options': [1, 2], 'nested': {'key': 'value'}})]"
        outcome = self._feed(chunk)

        self.assertEqual(outcome.normal_text, "")
        self.assertEqual(len(outcome.calls), 1)
        first_call = outcome.calls[0]
        self.assertEqual(first_call.name, "search")
        self.assertEqual(self.detector._buffer, "")

        # Check the parameters
        args = self._decode_args(first_call)
        self.assertEqual(args["query"], "test")
        config = args["config"]
        self.assertEqual(config["options"], [1, 2])
        self.assertEqual(config["nested"]["key"], "value")

    def test_parse_streaming_multiple_tools_with_nested_brackets(self):
        """Test parsing multiple tool calls with nested brackets."""
        chunk = "[get_weather(location='Paris', data=[10, 20]), search(query='test', filters=['a', 'b'])]"
        outcome = self._feed(chunk)

        self.assertEqual(outcome.normal_text, "")
        self.assertEqual(len(outcome.calls), 2)
        self.assertEqual(self.detector._buffer, "")

        # Check first tool call
        weather_args = self._decode_args(outcome.calls[0])
        self.assertEqual(outcome.calls[0].name, "get_weather")
        self.assertEqual(weather_args["location"], "Paris")
        self.assertEqual(weather_args["data"], [10, 20])

        # Check second tool call
        search_args = self._decode_args(outcome.calls[1])
        self.assertEqual(outcome.calls[1].name, "search")
        self.assertEqual(search_args["query"], "test")
        self.assertEqual(search_args["filters"], ["a", "b"])

    def test_parse_streaming_partial_nested_brackets(self):
        """Test parsing partial tool calls with nested brackets across chunks."""
        # First chunk with nested brackets but incomplete
        head = "Here's a call: [get_weather(location='Tokyo', data=[1, 2"
        head_outcome = self._feed(head)

        self.assertEqual(head_outcome.normal_text, "Here's a call: ")
        self.assertEqual(head_outcome.calls, [])
        self.assertEqual(
            self.detector._buffer, "[get_weather(location='Tokyo', data=[1, 2"
        )

        # Second chunk completing the nested brackets
        tail = ", 3])]"
        tail_outcome = self._feed(tail)

        self.assertEqual(tail_outcome.normal_text, "")
        self.assertEqual(len(tail_outcome.calls), 1)
        completed_call = tail_outcome.calls[0]
        self.assertEqual(completed_call.name, "get_weather")
        self.assertEqual(self.detector._buffer, "")

        # Check the parameters
        args = self._decode_args(completed_call)
        self.assertEqual(args["location"], "Tokyo")
        self.assertEqual(args["data"], [1, 2, 3])
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestEBNFGeneration(unittest.TestCase):
    """Check the EBNF grammar each function-call detector builds.

    Each test calls a detector's ``build_ebnf`` with the sample tools, looks
    for expected fragments in the returned grammar text, and then confirms
    the grammar compiles with xgrammar's ``GrammarCompiler``.
    """

    @classmethod
    def setUpClass(cls):
        """Build the tokenizer and grammar compiler once for the whole class.

        No test mutates either object, so they are created here instead of
        being rebuilt before every test method.
        """
        super().setUpClass()
        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
        tokenizer_info = TokenizerInfo.from_huggingface(cls.tokenizer)
        cls.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)

    def setUp(self):
        # Create sample tools for testing
        self.tools = [
            Tool(
                type="function",
                function=Function(
                    name="get_weather",
                    description="Get weather information",
                    parameters={
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "Location to get weather for",
                            },
                            "unit": {
                                "type": "string",
                                "description": "Temperature unit",
                                "enum": ["celsius", "fahrenheit"],
                            },
                        },
                        "required": ["location"],
                    },
                ),
            ),
            Tool(
                type="function",
                function=Function(
                    name="search",
                    description="Search for information",
                    parameters={
                        "properties": {
                            "query": {
                                "type": "string",
                                "description": "Search query",
                            },
                        },
                        "required": ["query"],
                    },
                ),
            ),
        ]

        # Initialize all detectors (fresh instances for every test)
        self.pythonic_detector = PythonicDetector()
        self.deepseekv3_detector = DeepSeekV3Detector()
        self.llama32_detector = Llama32Detector()
        self.mistral_detector = MistralDetector()
        self.qwen25_detector = Qwen25Detector()

    def _assert_ebnf_compiles(self, ebnf):
        """Fail the current test unless ``ebnf`` compiles with the grammar compiler.

        A ``RuntimeError`` raised by ``compile_grammar`` is reported as a test
        failure (not an error) that carries the compiler's message.
        """
        # Only the compile call sits inside the try, so nothing else can be
        # mistaken for a compile failure.
        try:
            ctx = self.grammar_compiler.compile_grammar(ebnf)
        except RuntimeError as e:
            self.fail(f"Failed to compile EBNF: {e}")
        self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully")

    def test_pythonic_detector_ebnf(self):
        """Test that the PythonicDetector generates valid EBNF."""
        ebnf = self.pythonic_detector.build_ebnf(self.tools)
        self.assertIsNotNone(ebnf)

        # Check that the EBNF contains expected patterns
        self.assertIn('call_get_weather ::= "get_weather" "(" ', ebnf)
        self.assertIn('"location" "=" basic_string', ebnf)
        self.assertIn('[ "unit" "=" ("\\"celsius\\"" | "\\"fahrenheit\\"") ]', ebnf)

        # Validate that the EBNF can be compiled by GrammarCompiler
        self._assert_ebnf_compiles(ebnf)

    def test_deepseekv3_detector_ebnf(self):
        """Test that the DeepSeekV3Detector generates valid EBNF."""
        ebnf = self.deepseekv3_detector.build_ebnf(self.tools)
        self.assertIsNotNone(ebnf)

        # Check that the EBNF contains expected patterns
        self.assertIn("<|tool▁calls▁begin|>", ebnf)
        self.assertIn("<|tool▁call▁begin|>function<|tool▁sep|>get_weather", ebnf)
        self.assertIn('\\"location\\"" ":" basic_string ', ebnf)

        # Validate that the EBNF can be compiled by GrammarCompiler
        self._assert_ebnf_compiles(ebnf)

    def test_llama32_detector_ebnf(self):
        """Test that the Llama32Detector generates valid EBNF."""
        ebnf = self.llama32_detector.build_ebnf(self.tools)
        self.assertIsNotNone(ebnf)

        # Check that the EBNF contains expected patterns
        self.assertIn('\\"name\\"" ":" "\\"get_weather\\"', ebnf)
        self.assertIn('"\\"arguments\\"" ":"', ebnf)

        # Validate that the EBNF can be compiled by GrammarCompiler
        self._assert_ebnf_compiles(ebnf)

    def test_mistral_detector_ebnf(self):
        """Test that the MistralDetector generates valid EBNF."""
        ebnf = self.mistral_detector.build_ebnf(self.tools)
        self.assertIsNotNone(ebnf)

        # Check that the EBNF contains expected patterns
        self.assertIn('"[TOOL_CALLS] ["', ebnf)
        self.assertIn("call_get_weather | call_search", ebnf)
        self.assertIn('"\\"arguments\\"" ":"', ebnf)

        # Validate that the EBNF can be compiled by GrammarCompiler
        self._assert_ebnf_compiles(ebnf)

    def test_qwen25_detector_ebnf(self):
        """Test that the Qwen25Detector generates valid EBNF."""
        ebnf = self.qwen25_detector.build_ebnf(self.tools)
        self.assertIsNotNone(ebnf)

        # Check that the EBNF contains expected patterns
        self.assertIn("<tool_call>", ebnf)
        self.assertIn('\\"name\\"" ":" "\\"get_weather\\"', ebnf)
        self.assertIn('"\\"arguments\\"" ":"', ebnf)

        # Validate that the EBNF can be compiled by GrammarCompiler
        self._assert_ebnf_compiles(ebnf)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|