bugfix(tool call ebnf): Fix EBNF generation for optional function parameters (#7283)
This commit is contained in:
@@ -211,20 +211,74 @@ class EBNFComposer:
|
|||||||
properties = params.get("properties", {})
|
properties = params.get("properties", {})
|
||||||
required_props = set(params.get("required", []))
|
required_props = set(params.get("required", []))
|
||||||
|
|
||||||
# Build argument rules for this tool
|
# The generated pattern ensures:
|
||||||
arg_rules = []
|
# 1. Required properties appear first, joined by commas
|
||||||
|
# 2. Optional properties are wrapped with the comma included: ( "," ( "prop" ":" value ) )?
|
||||||
|
# 3. For multiple optional properties, any subset may appear, in declaration order:
|
||||||
|
# - Each optional can be skipped entirely
|
||||||
|
# - Any combination may appear, but later optionals can only follow earlier ones
|
||||||
|
#
|
||||||
|
# Example patterns generated:
|
||||||
|
# - One required, one optional:
|
||||||
|
# "{" "location" ":" string ( "," ( "unit" ":" enum ) )? "}"
|
||||||
|
# Allows: {"location": "Paris"} or {"location": "Paris", "unit": "celsius"}
|
||||||
|
#
|
||||||
|
# - Multiple optional properties with flexible ordering:
|
||||||
|
# "{" "req" ":" string ( "," ( "opt1" ":" value ( "," "opt2" ":" value )? | "opt2" ":" value ) )? "}"
|
||||||
|
# Allows: {"req": "x"}, {"req": "x", "opt1": "y"}, {"req": "x", "opt2": "z"},
|
||||||
|
# {"req": "x", "opt1": "y", "opt2": "z"}
|
||||||
|
#
|
||||||
|
# - All optional properties with flexible ordering:
|
||||||
|
# "{" ( "opt1" ":" value ( "," "opt2" ":" value )? | "opt2" ":" value )? "}"
|
||||||
|
# Allows: {}, {"opt1": "x"}, {"opt2": "y"}, {"opt1": "x", "opt2": "y"}
|
||||||
|
|
||||||
|
prop_kv_pairs = {}
|
||||||
|
ordered_props = list(properties.keys())
|
||||||
|
|
||||||
for prop_name, prop_schema in properties.items():
|
for prop_name, prop_schema in properties.items():
|
||||||
value_rule = EBNFComposer.get_value_rule(prop_schema, function_format)
|
value_rule = EBNFComposer.get_value_rule(prop_schema, function_format)
|
||||||
# Create key=value pair
|
# Create key=value pair
|
||||||
pair = key_value_template.format(key=prop_name, valrule=value_rule)
|
pair = key_value_template.format(key=prop_name, valrule=value_rule)
|
||||||
|
prop_kv_pairs[prop_name] = pair
|
||||||
|
|
||||||
if prop_name not in required_props:
|
# Separate into required and optional while preserving order
|
||||||
pair = f"[ {pair} ]"
|
required = [p for p in ordered_props if p in required_props]
|
||||||
|
optional = [p for p in ordered_props if p not in required_props]
|
||||||
|
|
||||||
arg_rules.append(pair)
|
# Build the combined rule
|
||||||
|
rule_parts = []
|
||||||
|
|
||||||
# Combine all argument rules
|
# Add required properties joined by commas
|
||||||
combined_args = ' "," '.join(arg_rules) if arg_rules else ""
|
if required:
|
||||||
|
rule_parts.append(' "," '.join(prop_kv_pairs[k] for k in required))
|
||||||
|
|
||||||
|
# Add optional properties with flexible ordering
|
||||||
|
if optional:
|
||||||
|
# Build alternatives where any optional property can appear first
|
||||||
|
opt_alternatives = []
|
||||||
|
for i in range(len(optional)):
|
||||||
|
# Build pattern for optional[i] appearing first
|
||||||
|
opt_parts = []
|
||||||
|
for j in range(i, len(optional)):
|
||||||
|
if j == i:
|
||||||
|
opt_parts.append(prop_kv_pairs[optional[j]])
|
||||||
|
else:
|
||||||
|
opt_parts.append(f' ( "," {prop_kv_pairs[optional[j]]} )?')
|
||||||
|
opt_alternatives.append("".join(opt_parts))
|
||||||
|
|
||||||
|
# Wrap with appropriate comma handling based on whether we have required properties
|
||||||
|
if required:
|
||||||
|
# Required properties exist, so optional group needs outer comma
|
||||||
|
rule_parts.append(' ( "," ( ')
|
||||||
|
rule_parts.append(" | ".join(opt_alternatives))
|
||||||
|
rule_parts.append(" ) )?")
|
||||||
|
else:
|
||||||
|
# All properties are optional
|
||||||
|
rule_parts.append("( ")
|
||||||
|
rule_parts.append(" | ".join(opt_alternatives))
|
||||||
|
rule_parts.append(" )?")
|
||||||
|
|
||||||
|
combined_args = "".join(rule_parts)
|
||||||
arguments_rule = args_template.format(arg_rules=combined_args)
|
arguments_rule = args_template.format(arg_rules=combined_args)
|
||||||
|
|
||||||
# Add the function call rule and its arguments rule
|
# Add the function call rule and its arguments rule
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
"""
|
"""
|
||||||
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_email
|
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_email
|
||||||
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_greeting
|
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_greeting
|
||||||
|
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_all_optional_function_params
|
||||||
python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_email
|
python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_email
|
||||||
python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_greeting
|
python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_greeting
|
||||||
|
python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_all_optional_function_params
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
@@ -237,6 +239,38 @@ class TestEBNFConstrained(CustomTestCase):
|
|||||||
n=3,
|
n=3,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_ebnf_generate_all_optional_function_params(self):
|
||||||
|
"""Test function call with all optional parameters - verifies flexible ordering."""
|
||||||
|
self.__class__.ebnf_grammar = """
|
||||||
|
root ::= function_call
|
||||||
|
function_call ::= call_config_service
|
||||||
|
call_config_service ::= "{" "\\"name\\"" ":" "\\"config_service\\"" ", " "\\"arguments\\"" ":" arguments_config_service "}"
|
||||||
|
arguments_config_service ::= "{" ( "\\"theme\\"" ":" ("\\"light\\"" | "\\"dark\\"") ( "," "\\"language\\"" ":" ("\\"en\\"" | "\\"es\\"" | "\\"fr\\"") )? ( "," "\\"notifications\\"" ":" ("true" | "false") )? | "\\"language\\"" ":" ("\\"en\\"" | "\\"es\\"" | "\\"fr\\"") ( "," "\\"notifications\\"" ":" ("true" | "false") )? | "\\"notifications\\"" ":" ("true" | "false") )? "}"
|
||||||
|
"""
|
||||||
|
# Test patterns that should match - flexible ordering of optional parameters
|
||||||
|
allowed_patterns = [
|
||||||
|
# Empty arguments
|
||||||
|
r'^\{"name":"config_service", "arguments":\{\}\}$',
|
||||||
|
# Single optional parameters (any can appear first)
|
||||||
|
r'^\{"name":"config_service", "arguments":\{"theme":"(light|dark)"\}\}$',
|
||||||
|
r'^\{"name":"config_service", "arguments":\{"language":"(en|es|fr)"\}\}$',
|
||||||
|
r'^\{"name":"config_service", "arguments":\{"notifications":(true|false)\}\}$',
|
||||||
|
# Two optional parameters (any pair, in declaration order)
|
||||||
|
r'^\{"name":"config_service", "arguments":\{"theme":"(light|dark)", "language":"(en|es|fr)"\}\}$',
|
||||||
|
r'^\{"name":"config_service", "arguments":\{"theme":"(light|dark)", "notifications":(true|false)\}\}$',
|
||||||
|
r'^\{"name":"config_service", "arguments":\{"language":"(en|es|fr)", "notifications":(true|false)\}\}$',
|
||||||
|
# All three optional parameters
|
||||||
|
r'^\{"name":"config_service", "arguments":\{"theme":"(light|dark)", "language":"(en|es|fr)", "notifications":(true|false)\}\}$',
|
||||||
|
]
|
||||||
|
prompt = "Configure the service with optional settings:"
|
||||||
|
|
||||||
|
self.run_decode(
|
||||||
|
ebnf=self.__class__.ebnf_grammar,
|
||||||
|
expected_patterns=allowed_patterns,
|
||||||
|
prompt=prompt,
|
||||||
|
n=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestEBNFConstrainedLLGuidance(TestEBNFConstrained):
|
class TestEBNFConstrainedLLGuidance(TestEBNFConstrained):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -515,7 +515,7 @@ class TestEBNFGeneration(unittest.TestCase):
|
|||||||
# Check that the EBNF contains expected patterns
|
# Check that the EBNF contains expected patterns
|
||||||
self.assertIn('call_get_weather ::= "get_weather" "(" ', ebnf)
|
self.assertIn('call_get_weather ::= "get_weather" "(" ', ebnf)
|
||||||
self.assertIn('"location" "=" basic_string', ebnf)
|
self.assertIn('"location" "=" basic_string', ebnf)
|
||||||
self.assertIn('[ "unit" "=" ("\\"celsius\\"" | "\\"fahrenheit\\"") ]', ebnf)
|
self.assertIn('( "unit" "=" ("\\"celsius\\"" | "\\"fahrenheit\\"") )', ebnf)
|
||||||
|
|
||||||
# Validate that the EBNF can be compiled by GrammarCompiler
|
# Validate that the EBNF can be compiled by GrammarCompiler
|
||||||
try:
|
try:
|
||||||
@@ -591,6 +591,224 @@ class TestEBNFGeneration(unittest.TestCase):
|
|||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
self.fail(f"Failed to compile EBNF: {e}")
|
self.fail(f"Failed to compile EBNF: {e}")
|
||||||
|
|
||||||
|
def test_weather_function_optional_parameter_handling(self):
|
||||||
|
"""Test that weather function with optional unit parameter generates correct EBNF without trailing commas."""
|
||||||
|
# Create a weather tool with required location and optional unit
|
||||||
|
weather_tool = Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="get_current_weather",
|
||||||
|
description="Get the current weather in a given location",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test all detectors with the weather tool
|
||||||
|
detectors = {
|
||||||
|
"pythonic": self.pythonic_detector,
|
||||||
|
"deepseekv3": self.deepseekv3_detector,
|
||||||
|
"llama32": self.llama32_detector,
|
||||||
|
"mistral": self.mistral_detector,
|
||||||
|
"qwen25": self.qwen25_detector,
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, detector in detectors.items():
|
||||||
|
with self.subTest(detector=name):
|
||||||
|
ebnf = detector.build_ebnf([weather_tool])
|
||||||
|
self.assertIsNotNone(ebnf, f"{name} detector should generate EBNF")
|
||||||
|
|
||||||
|
# Check that the EBNF properly handles optional parameters
|
||||||
|
if name == "pythonic":
|
||||||
|
# Pythonic format: location="Paris" ( , ( unit=("celsius" | "fahrenheit") ) )?
|
||||||
|
self.assertIn('"location" "=" basic_string', ebnf)
|
||||||
|
# The comma should be inside the optional brackets for unit
|
||||||
|
self.assertIn('( "," ( "unit" "=" ', ebnf)
|
||||||
|
else:
|
||||||
|
# JSON format: "location": "Paris" ( , ( "unit": ("celsius" | "fahrenheit") ) )?
|
||||||
|
self.assertIn('"location\\"" ":" basic_string', ebnf)
|
||||||
|
# The comma should be part of the optional group
|
||||||
|
# This pattern ensures no trailing comma when unit is omitted
|
||||||
|
self.assertIn('( "," ( "\\"unit\\"" ":"', ebnf)
|
||||||
|
|
||||||
|
# Validate that the EBNF can be compiled
|
||||||
|
try:
|
||||||
|
ctx = self.grammar_compiler.compile_grammar(ebnf)
|
||||||
|
self.assertIsNotNone(
|
||||||
|
ctx, f"{name} EBNF should compile successfully"
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
self.fail(f"Failed to compile {name} EBNF: {e}")
|
||||||
|
|
||||||
|
def test_multiple_optional_parameters_flexible_ordering(self):
|
||||||
|
"""Test that multiple optional parameters allow flexible ordering using llama.cpp approach."""
|
||||||
|
# Create a tool with one required and multiple optional parameters
|
||||||
|
test_tool = Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="test_func",
|
||||||
|
description="Test function with multiple optional parameters",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"required_field": {"type": "string"},
|
||||||
|
"opt1": {"type": "number"},
|
||||||
|
"opt2": {"type": "boolean"},
|
||||||
|
"opt3": {"type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["required_field"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test JSON-based detectors (not pythonic)
|
||||||
|
json_detectors = {
|
||||||
|
"deepseekv3": self.deepseekv3_detector,
|
||||||
|
"llama32": self.llama32_detector,
|
||||||
|
"mistral": self.mistral_detector,
|
||||||
|
"qwen25": self.qwen25_detector,
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, detector in json_detectors.items():
|
||||||
|
with self.subTest(detector=name):
|
||||||
|
ebnf = detector.build_ebnf([test_tool])
|
||||||
|
self.assertIsNotNone(ebnf, f"{name} detector should generate EBNF")
|
||||||
|
|
||||||
|
# Extract the arguments rule so its structure can be checked
|
||||||
|
lines = ebnf.split("\n")
|
||||||
|
args_rule = None
|
||||||
|
for line in lines:
|
||||||
|
if line.startswith("arguments_test_func ::="):
|
||||||
|
args_rule = line
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertIsNotNone(
|
||||||
|
args_rule, f"{name} should have arguments_test_func rule"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check required field
|
||||||
|
self.assertIn('"required_field\\"" ":" basic_string', ebnf)
|
||||||
|
|
||||||
|
# Check the structure for optional parameters
|
||||||
|
# The pattern should be: required_field ( "," ( opt1 ... | opt2 ... | opt3 ... ) )?
|
||||||
|
# This allows flexible ordering where any optional can be first
|
||||||
|
|
||||||
|
# Check that optional parameters are in a group with comma
|
||||||
|
if args_rule: # Only check if args_rule was found
|
||||||
|
self.assertIn(
|
||||||
|
'( ","',
|
||||||
|
args_rule,
|
||||||
|
f"{name} should have comma grouped with optional parameters",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for the alternation pattern that allows flexible ordering
|
||||||
|
# Should contain patterns like: opt1 ... | opt2 ... | opt3
|
||||||
|
self.assertIn('"opt1\\"" ":" basic_number', args_rule)
|
||||||
|
self.assertIn('"opt2\\"" ":" basic_boolean', args_rule)
|
||||||
|
self.assertIn('"opt3\\"" ":" basic_string', args_rule)
|
||||||
|
|
||||||
|
# Check for alternation (|) which allows skipping optional parameters
|
||||||
|
self.assertIn(
|
||||||
|
"|",
|
||||||
|
args_rule,
|
||||||
|
f"{name} should use alternation for flexible optional ordering",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that the pattern ends properly with closing braces
|
||||||
|
self.assertTrue(
|
||||||
|
args_rule.endswith('"}"'),
|
||||||
|
f"{name} arguments rule should end with closing brace",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate compilation
|
||||||
|
try:
|
||||||
|
ctx = self.grammar_compiler.compile_grammar(ebnf)
|
||||||
|
self.assertIsNotNone(
|
||||||
|
ctx, f"{name} EBNF should compile successfully"
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
self.fail(f"Failed to compile {name} EBNF: {e}")
|
||||||
|
|
||||||
|
def test_all_optional_parameters_ordering(self):
|
||||||
|
"""Test the behavior when ALL parameters are optional - verifies ordering constraints."""
|
||||||
|
# Create a tool with only optional parameters
|
||||||
|
all_optional_tool = Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="optional_func",
|
||||||
|
description="Function with all optional parameters",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"opt1": {"type": "string"},
|
||||||
|
"opt2": {"type": "number"},
|
||||||
|
"opt3": {"type": "boolean"},
|
||||||
|
},
|
||||||
|
"required": [], # No required parameters
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test JSON-based detectors
|
||||||
|
json_detectors = {
|
||||||
|
"deepseekv3": self.deepseekv3_detector,
|
||||||
|
"llama32": self.llama32_detector,
|
||||||
|
"mistral": self.mistral_detector,
|
||||||
|
"qwen25": self.qwen25_detector,
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, detector in json_detectors.items():
|
||||||
|
with self.subTest(detector=name):
|
||||||
|
ebnf = detector.build_ebnf([all_optional_tool])
|
||||||
|
self.assertIsNotNone(ebnf, f"{name} detector should generate EBNF")
|
||||||
|
|
||||||
|
# Extract the arguments rule
|
||||||
|
lines = ebnf.split("\n")
|
||||||
|
args_rule = None
|
||||||
|
for line in lines:
|
||||||
|
if line.startswith("arguments_optional_func ::="):
|
||||||
|
args_rule = line
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertIsNotNone(
|
||||||
|
args_rule, f"{name} should have arguments_optional_func rule"
|
||||||
|
)
|
||||||
|
|
||||||
|
if args_rule:
|
||||||
|
# When all parameters are optional, the pattern now uses alternation:
|
||||||
|
# "{" ( opt1 ... | opt2 ... | opt3 ... )? "}"
|
||||||
|
# This allows flexible ordering where any optional can appear first
|
||||||
|
|
||||||
|
# Check the structure
|
||||||
|
self.assertIn('"opt1\\"" ":" basic_string', args_rule)
|
||||||
|
self.assertIn('"opt2\\"" ":" basic_number', args_rule)
|
||||||
|
self.assertIn('"opt3\\"" ":" basic_boolean', args_rule)
|
||||||
|
|
||||||
|
# The pattern SHOULD have alternation (|) for flexible ordering
|
||||||
|
self.assertIn(
|
||||||
|
"|",
|
||||||
|
args_rule,
|
||||||
|
f"{name} should use alternation for flexible ordering even when all properties are optional",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate compilation
|
||||||
|
try:
|
||||||
|
ctx = self.grammar_compiler.compile_grammar(ebnf)
|
||||||
|
self.assertIsNotNone(
|
||||||
|
ctx, f"{name} EBNF should compile successfully"
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
self.fail(f"Failed to compile {name} EBNF: {e}")
|
||||||
|
|
||||||
|
|
||||||
class TestBaseFormatDetector(unittest.TestCase):
|
class TestBaseFormatDetector(unittest.TestCase):
|
||||||
"""Test buffer management and sequential tool index assignment in BaseFormatDetector."""
|
"""Test buffer management and sequential tool index assignment in BaseFormatDetector."""
|
||||||
|
|||||||
@@ -77,7 +77,11 @@ class TestToolChoiceLlama32(CustomTestCase):
|
|||||||
"city": {
|
"city": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "name of the city to get weather for",
|
"description": "name of the city to get weather for",
|
||||||
}
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"required": ["city"],
|
"required": ["city"],
|
||||||
},
|
},
|
||||||
@@ -152,7 +156,7 @@ class TestToolChoiceLlama32(CustomTestCase):
|
|||||||
"enum": ["celsius", "fahrenheit"],
|
"enum": ["celsius", "fahrenheit"],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"required": ["location", "unit"],
|
"required": ["location"],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
Reference in New Issue
Block a user