feat(Tool Calling): Support required and specific function mode (#6550)

This commit is contained in:
Chang Su
2025-05-23 21:00:37 -07:00
committed by GitHub
parent e6f113569e
commit ed0c3035cd
17 changed files with 2022 additions and 883 deletions

View File

@@ -36,7 +36,7 @@ suites = {
TestFile("test_fa3.py", 376),
TestFile("test_fim_completion.py", 40),
TestFile("test_fp8_kernel.py", 8),
TestFile("test_function_calling.py", 60),
TestFile("test_function_call_parser.py", 10),
TestFile("test_fused_moe.py", 30),
TestFile("test_hicache.py", 116),
TestFile("test_hicache_mla.py", 254),
@@ -54,6 +54,7 @@ suites = {
TestFile("test_flashmla.py", 300),
TestFile("test_no_chunked_prefill.py", 108),
TestFile("test_no_overlap_scheduler.py", 216),
TestFile("test_openai_function_calling.py", 60),
TestFile("test_openai_server.py", 149),
TestFile("test_penalty.py", 41),
TestFile("test_page_size.py", 60),

View File

@@ -0,0 +1,408 @@
import json
import unittest
from xgrammar import GrammarCompiler, TokenizerInfo
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.openai_api.protocol import Function, Tool
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
class TestPythonicDetector(unittest.TestCase):
def setUp(self):
# Create sample tools for testing
self.tools = [
Tool(
type="function",
function=Function(
name="get_weather",
description="Get weather information",
parameters={
"properties": {
"location": {
"type": "string",
"description": "Location to get weather for",
},
"unit": {
"type": "string",
"description": "Temperature unit",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
},
),
),
Tool(
type="function",
function=Function(
name="search",
description="Search for information",
parameters={
"properties": {
"query": {
"type": "string",
"description": "Search query",
},
},
"required": ["query"],
},
),
),
]
self.detector = PythonicDetector()
def test_parse_streaming_no_brackets(self):
"""Test parsing text with no brackets (no tool calls)."""
text = "This is just normal text without any tool calls."
result = self.detector.parse_streaming_increment(text, self.tools)
self.assertEqual(result.normal_text, text)
self.assertEqual(result.calls, [])
self.assertEqual(self.detector._buffer, "") # Buffer should be cleared
def test_parse_streaming_complete_tool_call(self):
"""Test parsing a complete tool call."""
text = "Here's a tool call: [get_weather(location='New York', unit='celsius')]"
result = self.detector.parse_streaming_increment(text, self.tools)
self.assertEqual(result.normal_text, "Here's a tool call: ")
self.assertEqual(len(result.calls), 1)
self.assertEqual(result.calls[0].name, "get_weather")
self.assertEqual(
self.detector._buffer, ""
) # Buffer should be cleared after processing
# Check the parameters
params = json.loads(result.calls[0].parameters)
self.assertEqual(params["location"], "New York")
self.assertEqual(params["unit"], "celsius")
def test_parse_streaming_text_before_tool_call(self):
"""Test parsing text that appears before a tool call."""
text = "This is some text before [get_weather(location='London')]"
result = self.detector.parse_streaming_increment(text, self.tools)
self.assertEqual(result.normal_text, "This is some text before ")
self.assertEqual(len(result.calls), 1)
self.assertEqual(result.calls[0].name, "get_weather")
# Check the parameters
params = json.loads(result.calls[0].parameters)
self.assertEqual(params["location"], "London")
def test_parse_streaming_partial_tool_call(self):
"""Test parsing a partial tool call that spans multiple chunks."""
# First chunk with opening bracket but no closing bracket
text1 = "Let me check the weather: [get_weather(location="
result1 = self.detector.parse_streaming_increment(text1, self.tools)
self.assertEqual(result1.normal_text, "Let me check the weather: ")
self.assertEqual(result1.calls, [])
self.assertEqual(
self.detector._buffer, "[get_weather(location="
) # Partial tool call remains in buffer
# Second chunk completing the tool call
text2 = "'Paris')]"
result2 = self.detector.parse_streaming_increment(text2, self.tools)
self.assertEqual(result2.normal_text, "")
self.assertEqual(len(result2.calls), 1)
self.assertEqual(result2.calls[0].name, "get_weather")
# Check the parameters
params = json.loads(result2.calls[0].parameters)
self.assertEqual(params["location"], "Paris")
self.assertEqual(
self.detector._buffer, ""
) # Buffer should be cleared after processing
def test_parse_streaming_bracket_without_text_before(self):
"""Test parsing a tool call that starts at the beginning of the text."""
text = "[search(query='python programming')]"
result = self.detector.parse_streaming_increment(text, self.tools)
self.assertEqual(result.normal_text, "")
self.assertEqual(len(result.calls), 1)
self.assertEqual(result.calls[0].name, "search")
# Check the parameters
params = json.loads(result.calls[0].parameters)
self.assertEqual(params["query"], "python programming")
def test_parse_streaming_text_after_tool_call(self):
"""Test parsing text that appears after a tool call."""
# First chunk with complete tool call and some text after
text = "[get_weather(location='Tokyo')] Here's the forecast:"
result = self.detector.parse_streaming_increment(text, self.tools)
self.assertEqual(result.normal_text, "")
self.assertEqual(len(result.calls), 1)
self.assertEqual(result.calls[0].name, "get_weather")
self.assertEqual(
self.detector._buffer, " Here's the forecast:"
) # Text after tool call remains in buffer
# Process the remaining text in buffer
result2 = self.detector.parse_streaming_increment("", self.tools)
self.assertEqual(result2.normal_text, " Here's the forecast:")
self.assertEqual(result2.calls, [])
self.assertEqual(self.detector._buffer, "") # Buffer should be cleared
def test_parse_streaming_multiple_tool_calls(self):
"""Test parsing multiple tool calls in sequence."""
text = "[get_weather(location='Berlin')] and [search(query='restaurants')]"
# First tool call
result1 = self.detector.parse_streaming_increment(text, self.tools)
self.assertEqual(len(result1.calls), 1)
self.assertEqual(result1.calls[0].name, "get_weather")
self.assertEqual(self.detector._buffer, " and [search(query='restaurants')]")
# Second tool call
result2 = self.detector.parse_streaming_increment("", self.tools)
self.assertEqual(result2.normal_text, " and ")
self.assertEqual(len(result2.calls), 1)
self.assertEqual(result2.calls[0].name, "search")
self.assertEqual(self.detector._buffer, "")
def test_parse_streaming_opening_bracket_only(self):
"""Test parsing text with only an opening bracket but no closing bracket."""
text = "Let's try this: ["
result = self.detector.parse_streaming_increment(text, self.tools)
self.assertEqual(result.normal_text, "Let's try this: ")
self.assertEqual(result.calls, [])
self.assertEqual(
self.detector._buffer, "["
) # Opening bracket remains in buffer
def test_parse_streaming_nested_brackets(self):
"""Test parsing tool calls with nested brackets in arguments."""
# Test with list argument containing nested brackets
text = "[get_weather(location='New York', unit='celsius', data=[1, 2, 3])]"
result = self.detector.parse_streaming_increment(text, self.tools)
self.assertEqual(result.normal_text, "")
self.assertEqual(len(result.calls), 1)
self.assertEqual(result.calls[0].name, "get_weather")
self.assertEqual(self.detector._buffer, "")
# Check the parameters
params = json.loads(result.calls[0].parameters)
self.assertEqual(params["location"], "New York")
self.assertEqual(params["unit"], "celsius")
self.assertEqual(params["data"], [1, 2, 3])
def test_parse_streaming_nested_brackets_dict(self):
"""Test parsing tool calls with nested dictionaries and lists."""
# Test with nested dict and list arguments
text = "[search(query='test', config={'options': [1, 2], 'nested': {'key': 'value'}})]"
result = self.detector.parse_streaming_increment(text, self.tools)
self.assertEqual(result.normal_text, "")
self.assertEqual(len(result.calls), 1)
self.assertEqual(result.calls[0].name, "search")
self.assertEqual(self.detector._buffer, "")
# Check the parameters
params = json.loads(result.calls[0].parameters)
self.assertEqual(params["query"], "test")
self.assertEqual(params["config"]["options"], [1, 2])
self.assertEqual(params["config"]["nested"]["key"], "value")
def test_parse_streaming_multiple_tools_with_nested_brackets(self):
"""Test parsing multiple tool calls with nested brackets."""
text = "[get_weather(location='Paris', data=[10, 20]), search(query='test', filters=['a', 'b'])]"
result = self.detector.parse_streaming_increment(text, self.tools)
self.assertEqual(result.normal_text, "")
self.assertEqual(len(result.calls), 2)
self.assertEqual(self.detector._buffer, "")
# Check first tool call
params1 = json.loads(result.calls[0].parameters)
self.assertEqual(result.calls[0].name, "get_weather")
self.assertEqual(params1["location"], "Paris")
self.assertEqual(params1["data"], [10, 20])
# Check second tool call
params2 = json.loads(result.calls[1].parameters)
self.assertEqual(result.calls[1].name, "search")
self.assertEqual(params2["query"], "test")
self.assertEqual(params2["filters"], ["a", "b"])
def test_parse_streaming_partial_nested_brackets(self):
"""Test parsing partial tool calls with nested brackets across chunks."""
# First chunk with nested brackets but incomplete
text1 = "Here's a call: [get_weather(location='Tokyo', data=[1, 2"
result1 = self.detector.parse_streaming_increment(text1, self.tools)
self.assertEqual(result1.normal_text, "Here's a call: ")
self.assertEqual(result1.calls, [])
self.assertEqual(
self.detector._buffer, "[get_weather(location='Tokyo', data=[1, 2"
)
# Second chunk completing the nested brackets
text2 = ", 3])]"
result2 = self.detector.parse_streaming_increment(text2, self.tools)
self.assertEqual(result2.normal_text, "")
self.assertEqual(len(result2.calls), 1)
self.assertEqual(result2.calls[0].name, "get_weather")
self.assertEqual(self.detector._buffer, "")
# Check the parameters
params = json.loads(result2.calls[0].parameters)
self.assertEqual(params["location"], "Tokyo")
self.assertEqual(params["data"], [1, 2, 3])
class TestEBNFGeneration(unittest.TestCase):
def setUp(self):
# Create sample tools for testing
self.tools = [
Tool(
type="function",
function=Function(
name="get_weather",
description="Get weather information",
parameters={
"properties": {
"location": {
"type": "string",
"description": "Location to get weather for",
},
"unit": {
"type": "string",
"description": "Temperature unit",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
},
),
),
Tool(
type="function",
function=Function(
name="search",
description="Search for information",
parameters={
"properties": {
"query": {
"type": "string",
"description": "Search query",
},
},
"required": ["query"],
},
),
),
]
self.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
tokenizer_info = TokenizerInfo.from_huggingface(self.tokenizer)
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
# Initialize all detectors
self.pythonic_detector = PythonicDetector()
self.deepseekv3_detector = DeepSeekV3Detector()
self.llama32_detector = Llama32Detector()
self.mistral_detector = MistralDetector()
self.qwen25_detector = Qwen25Detector()
def test_pythonic_detector_ebnf(self):
"""Test that the PythonicDetector generates valid EBNF."""
ebnf = self.pythonic_detector.build_ebnf(self.tools)
self.assertIsNotNone(ebnf)
# Check that the EBNF contains expected patterns
self.assertIn('call_get_weather ::= "get_weather" "(" ', ebnf)
self.assertIn('"location" "=" basic_string', ebnf)
self.assertIn('[ "unit" "=" ("\\"celsius\\"" | "\\"fahrenheit\\"") ]', ebnf)
# Validate that the EBNF can be compiled by GrammarCompiler
try:
ctx = self.grammar_compiler.compile_grammar(ebnf)
self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully")
except RuntimeError as e:
self.fail(f"Failed to compile EBNF: {e}")
def test_deepseekv3_detector_ebnf(self):
"""Test that the DeepSeekV3Detector generates valid EBNF."""
ebnf = self.deepseekv3_detector.build_ebnf(self.tools)
self.assertIsNotNone(ebnf)
# Check that the EBNF contains expected patterns
self.assertIn("<tool▁calls▁begin>", ebnf)
self.assertIn("<tool▁call▁begin>function<tool▁sep>get_weather", ebnf)
self.assertIn('\\"location\\"" ":" basic_string ', ebnf)
# Validate that the EBNF can be compiled by GrammarCompiler
try:
ctx = self.grammar_compiler.compile_grammar(ebnf)
self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully")
except RuntimeError as e:
self.fail(f"Failed to compile EBNF: {e}")
def test_llama32_detector_ebnf(self):
"""Test that the Llama32Detector generates valid EBNF."""
ebnf = self.llama32_detector.build_ebnf(self.tools)
self.assertIsNotNone(ebnf)
# Check that the EBNF contains expected patterns
self.assertIn('\\"name\\"" ":" "\\"get_weather\\"', ebnf)
self.assertIn('"\\"arguments\\"" ":"', ebnf)
# Validate that the EBNF can be compiled by GrammarCompiler
try:
ctx = self.grammar_compiler.compile_grammar(ebnf)
self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully")
except RuntimeError as e:
self.fail(f"Failed to compile EBNF: {e}")
def test_mistral_detector_ebnf(self):
"""Test that the MistralDetector generates valid EBNF."""
ebnf = self.mistral_detector.build_ebnf(self.tools)
self.assertIsNotNone(ebnf)
# Check that the EBNF contains expected patterns
self.assertIn('"[TOOL_CALLS] ["', ebnf)
self.assertIn("call_get_weather | call_search", ebnf)
self.assertIn('"\\"arguments\\"" ":"', ebnf)
# Validate that the EBNF can be compiled by GrammarCompiler
try:
ctx = self.grammar_compiler.compile_grammar(ebnf)
self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully")
except RuntimeError as e:
self.fail(f"Failed to compile EBNF: {e}")
def test_qwen25_detector_ebnf(self):
"""Test that the Qwen25Detector generates valid EBNF."""
ebnf = self.qwen25_detector.build_ebnf(self.tools)
self.assertIsNotNone(ebnf)
# Check that the EBNF contains expected patterns
self.assertIn("<tool_call>", ebnf)
self.assertIn('\\"name\\"" ":" "\\"get_weather\\"', ebnf)
self.assertIn('"\\"arguments\\"" ":"', ebnf)
# Validate that the EBNF can be compiled by GrammarCompiler
try:
ctx = self.grammar_compiler.compile_grammar(ebnf)
self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully")
except RuntimeError as e:
self.fail(f"Failed to compile EBNF: {e}")
if __name__ == "__main__":
unittest.main()

View File

@@ -290,6 +290,151 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
self.assertEqual(str(args_obj["int_a"]), "5", "Parameter int_a should be 5")
self.assertEqual(str(args_obj["int_b"]), "7", "Parameter int_b should be 7")
def test_function_call_required(self):
"""
Test: Whether tool_choice: "required" works as expected
- When tool_choice == "required", the model should return one or more tool_calls.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
tools = [
{
"type": "function",
"function": {
"name": "sub",
"description": "Compute the difference of two integers",
"parameters": {
"type": "object",
"properties": {
"int_a": {
"type": "integer",
"description": "First integer",
},
"int_b": {
"type": "integer",
"description": "Second integer",
},
},
"required": ["int_a", "int_b"],
},
"strict": True,
},
},
{
"type": "function",
"function": {
"name": "get_weather",
"description": "use this to get latest weather information for a city given its name",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "name of the city to get weather for",
}
},
"required": ["city"],
},
},
},
]
messages = [{"role": "user", "content": "What is the capital of France?"}]
response = client.chat.completions.create(
model=self.model,
messages=messages,
temperature=0.8,
top_p=0.8,
stream=False,
tools=tools,
tool_choice="required",
)
tool_calls = response.choices[0].message.tool_calls
self.assertIsNotNone(tool_calls, "No tool_calls in the response")
function_name = tool_calls[0].function.name
arguments = tool_calls[0].function.arguments
args_obj = json.loads(arguments)
self.assertEqual(
function_name, "get_weather", "Function name should be 'get_weather'"
)
self.assertIn("city", args_obj, "Function arguments should have 'city'")
self.assertIn(
"Paris", args_obj["city"], "Parameter city should contain 'Paris'"
) # might be flaky
def test_function_call_specific(self):
"""
Test: Whether tool_choice: ToolChoice works as expected
- When tool_choice is a specific ToolChoice, the model should return one or more tool_calls.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
tools = [
{
"type": "function",
"function": {
"name": "sub",
"description": "Compute the difference of two integers",
"parameters": {
"type": "object",
"properties": {
"int_a": {
"type": "integer",
"description": "First integer",
},
"int_b": {
"type": "integer",
"description": "Second integer",
},
},
"required": ["int_a", "int_b"],
},
"strict": True,
},
},
{
"type": "function",
"function": {
"name": "get_weather",
"description": "use this to get latest weather information for a city given its name",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "name of the city to get weather for",
}
},
"required": ["city"],
},
},
},
]
messages = [{"role": "user", "content": "What is the capital of France?"}]
response = client.chat.completions.create(
model=self.model,
messages=messages,
temperature=0.8,
top_p=0.8,
stream=False,
tools=tools,
tool_choice={"type": "function", "function": {"name": "get_weather"}},
)
tool_calls = response.choices[0].message.tool_calls
self.assertIsNotNone(tool_calls, "No tool_calls in the response")
function_name = tool_calls[0].function.name
arguments = tool_calls[0].function.arguments
args_obj = json.loads(arguments)
self.assertEqual(
function_name, "get_weather", "Function name should be 'get_weather'"
)
self.assertIn("city", args_obj, "Function arguments should have 'city'")
class TestOpenAIPythonicFunctionCalling(CustomTestCase):
PYTHONIC_TOOLS = [
@@ -385,11 +530,13 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase):
stream=False,
)
tool_calls = response.choices[0].message.tool_calls
self.assertIsInstance(tool_calls, list)
self.assertIsInstance(tool_calls, list, "No tool_calls found")
self.assertGreaterEqual(len(tool_calls), 1)
names = [tc.function.name for tc in tool_calls]
self.assertIn("get_weather", names)
self.assertIn("get_tourist_attractions", names)
self.assertTrue(
"get_weather" in names or "get_tourist_attractions" in names,
f"Function name '{names}' should container either 'get_weather' or 'get_tourist_attractions'",
)
def test_pythonic_tool_call_streaming(self):
"""
@@ -419,8 +566,10 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase):
self.assertTrue(found_tool_calls, "No tool_calls found in streaming response")
self.assertTrue(found_index, "No index field found in any streamed tool_call")
self.assertIn("get_weather", found_names)
self.assertIn("get_tourist_attractions", found_names)
self.assertTrue(
"get_weather" in found_names or "get_tourist_attractions" in found_names,
f"Function name '{found_names}' should container either 'get_weather' or 'get_tourist_attractions'",
)
if __name__ == "__main__":