Use jsonschema to constrain required or specific tool choice (#10550)

This commit is contained in:
Tejesh Anand
2025-09-27 10:18:50 -07:00
committed by GitHub
parent 9c339d6b47
commit 8cc27fdc46
12 changed files with 1558 additions and 50 deletions

View File

@@ -73,11 +73,11 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
"type": "object",
"properties": {
"a": {
"type": "int",
"type": "integer",
"description": "A number",
},
"b": {
"type": "int",
"type": "integer",
"description": "A number",
},
},
@@ -128,11 +128,11 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
"type": "object",
"properties": {
"a": {
"type": "int",
"type": "integer",
"description": "A number",
},
"b": {
"type": "int",
"type": "integer",
"description": "A number",
},
},

View File

@@ -343,6 +343,142 @@ class TestToolChoiceLlama32(CustomTestCase):
self.assertEqual(found_name, "get_weather")
def test_required_streaming_arguments_chunks_json(self):
"""In streaming required mode, complete tool call arguments should be valid JSON when all chunks are combined"""
tools = self.get_test_tools()
messages = self.get_test_messages()
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=1024,
temperature=0.1,
tools=tools,
tool_choice="required",
stream=True,
)
# Collect all tool call chunks and reconstruct complete tool calls
tool_calls_by_index = {}
for chunk in response:
if chunk.choices[0].delta.tool_calls:
for tool_call_delta in chunk.choices[0].delta.tool_calls:
tool_index = tool_call_delta.index
# Initialize tool call if not seen before
if tool_index not in tool_calls_by_index:
tool_calls_by_index[tool_index] = {
"id": tool_call_delta.id,
"type": "function",
"function": {"name": "", "arguments": ""},
}
# Update function name if present (first chunk)
if tool_call_delta.function and tool_call_delta.function.name:
tool_calls_by_index[tool_index]["function"][
"name"
] = tool_call_delta.function.name
# Accumulate arguments (all chunks)
if tool_call_delta.function and tool_call_delta.function.arguments:
tool_calls_by_index[tool_index]["function"][
"arguments"
] += tool_call_delta.function.arguments
self.assertGreater(len(tool_calls_by_index), 0)
# Validate that complete tool calls have valid JSON arguments
for tool_call in tool_calls_by_index.values():
self.assertIsNotNone(tool_call["function"]["name"])
self.assertIsNotNone(tool_call["function"]["arguments"])
# The complete arguments should be valid JSON
try:
args = json.loads(tool_call["function"]["arguments"])
self.assertIsInstance(args, dict)
except json.JSONDecodeError:
self.fail(
f"Invalid JSON in complete tool call arguments: {tool_call['function']['arguments']}"
)
def test_complex_parameters_required_non_streaming(self):
"""Validate complex nested parameter schemas in non-streaming required mode"""
complex_tools = [
{
"type": "function",
"function": {
"name": "analyze_data",
"description": "Analyze complex data structures",
"parameters": {
"type": "object",
"properties": {
"data": {
"type": "object",
"properties": {
"metrics": {
"type": "array",
"items": {"type": "string"},
},
"config": {
"type": "object",
"properties": {
"threshold": {"type": "number"},
"enabled": {"type": "boolean"},
},
},
},
"required": ["metrics"],
},
"options": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": {"type": "string"},
"value": {"type": "string"},
},
},
},
},
"required": ["data"],
},
},
}
]
messages = [
{
"role": "user",
"content": "Analyze some data with metrics and configuration",
}
]
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=1024,
temperature=0.1,
tools=complex_tools,
tool_choice="required",
stream=False,
)
tool_calls = response.choices[0].message.tool_calls
self.assertIsNotNone(tool_calls)
self.assertGreater(len(tool_calls), 0)
for tool_call in tool_calls:
self.assertEqual(tool_call.function.name, "analyze_data")
try:
args = json.loads(tool_call.function.arguments)
self.assertIsInstance(args, dict)
self.assertIn("data", args)
self.assertIsInstance(args["data"], dict)
except json.JSONDecodeError:
self.fail(
f"Invalid JSON in complex tool call arguments: {tool_call.function.arguments}"
)
def test_multi_tool_scenario_auto(self):
"""Test multi-tool scenario with tool_choice='auto'"""
tools = self.get_travel_tools()
@@ -408,6 +544,10 @@ class TestToolChoiceLlama32(CustomTestCase):
available_names = [tool["function"]["name"] for tool in tools]
expected_functions = {"get_weather", "get_tourist_attractions"}
for tool_call in tool_calls:
self.assertIsNotNone(tool_call.function.name)
self.assertIsNotNone(tool_call.function.arguments)
if self._is_flaky_test():
# For flaky tests, just ensure basic functionality works
self.assertGreater(
@@ -432,22 +572,15 @@ class TestToolChoiceLlama32(CustomTestCase):
def test_error_handling_invalid_tool_choice(self):
"""Test error handling for invalid tool_choice"""
import logging
from unittest.mock import patch
tools = self.get_test_tools()
messages = self.get_test_messages()
# Test with invalid function name
tool_choice = {"type": "function", "function": {"name": "nonexistent_function"}}
# The behavior could be either:
# 1. Log a warning and continue (if fallback is implemented)
# 2. Raise an exception (if strict validation is implemented)
# First try to capture any logging that might happen
with patch("logging.warning") as mock_warning:
response = self.client.chat.completions.create(
# Expect a 400 BadRequestError to be raised for invalid tool_choice
with self.assertRaises(openai.BadRequestError) as context:
self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=2048,
@@ -456,11 +589,173 @@ class TestToolChoiceLlama32(CustomTestCase):
stream=False,
)
self.assertIsNotNone(response.choices[0].message)
# Verify the error message contains the expected text
self.assertIn(
"Tool 'nonexistent_function' not found in tools list",
str(context.exception),
)
if mock_warning.called:
warning_message = mock_warning.call_args[0][0]
self.assertIn("nonexistent_function", warning_message)
def test_invalid_tool_missing_name(self):
"""Test what happens when user doesn't provide a tool name in request"""
# Test with malformed JSON in tool parameters - missing required "name" field
invalid_tools = [
{
"type": "function",
"function": {
# Missing required "name" field
"description": "Test function with invalid schema",
"parameters": {
"type": "object",
"properties": {
"test_field": {
"type": "string",
"description": "Test field",
}
},
"required": ["test_field"],
},
},
}
]
messages = [
{
"role": "user",
"content": "Test the function",
}
]
# Should raise BadRequestError due to missing required 'name' field
with self.assertRaises(openai.BadRequestError) as context:
self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=100,
temperature=0.1,
tools=invalid_tools,
tool_choice="required",
stream=False,
)
# Verify the error message indicates missing name field
error_msg = str(context.exception).lower()
self.assertIn("name", error_msg)
def test_invalid_json_schema_in_tool(self):
"""Test what happens when tool function has invalid JSON schema"""
invalid_tools = [
{
"type": "function",
"function": {
"name": "test_function",
"description": "Test function with invalid JSON schema",
"parameters": {
"type": "object",
"properties": {
"invalid_field": {
"type": "unknown_type", # Invalid type
"description": "This field has an invalid type",
}
},
"required": ["invalid_field"],
},
},
}
]
messages = [
{
"role": "user",
"content": "Test the function",
}
]
# Should raise BadRequestError due to invalid JSON schema in tool parameters
with self.assertRaises(openai.BadRequestError) as context:
self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=100,
temperature=0.1,
tools=invalid_tools,
tool_choice="required",
stream=False,
)
# Verify the error message indicates invalid JSON schema for parameters field
error_msg = str(context.exception).lower()
self.assertIn("invalid 'parameters' schema", error_msg)
def test_conflicting_defs_required_tool_choice(self):
"""Test that conflicting $defs with required tool_choice returns 400 error"""
conflicting_tools = [
{
"type": "function",
"function": {
"name": "tool1",
"description": "Tool 1 with conflicting $defs",
"parameters": {
"type": "object",
"properties": {
"data": {"$ref": "#/$defs/DataType"},
},
"required": ["data"],
"$defs": {
"DataType": {
"type": "object",
"properties": {"value": {"type": "string"}},
"required": ["value"],
},
},
},
},
},
{
"type": "function",
"function": {
"name": "tool2",
"description": "Tool 2 with conflicting $defs",
"parameters": {
"type": "object",
"properties": {
"data": {"$ref": "#/$defs/DataType"},
},
"required": ["data"],
"$defs": {
"DataType": { # Different definition for DataType
"type": "object",
"properties": {"value": {"type": "number"}},
"required": ["value"],
},
},
},
},
},
]
messages = [
{
"role": "user",
"content": "Test the conflicting tools",
}
]
# Should raise BadRequestError due to conflicting $defs
with self.assertRaises(openai.BadRequestError) as context:
self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=100,
temperature=0.1,
tools=conflicting_tools,
tool_choice="required",
stream=False,
)
# Verify the error message indicates conflicting tool definitions
error_msg = str(context.exception).lower()
self.assertIn("multiple schemas", error_msg)
self.assertIn("not supported", error_msg)
class TestToolChoiceQwen25(TestToolChoiceLlama32):
@@ -516,6 +811,16 @@ class TestToolChoiceMistral(TestToolChoiceLlama32):
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(cls.model)
@unittest.skip("Fails due to whitespace issue with Mistral - skipping")
def test_multi_tool_scenario_required(self):
"""Test multi-tool scenario with tool_choice='required'"""
super().test_multi_tool_scenario_required()
@unittest.skip("Fails due to whitespace issue with Mistral - skipping")
def test_complex_parameters_required_non_streaming(self):
"""Validate complex nested parameter schemas in non-streaming required mode"""
super().test_complex_parameters_required_non_streaming()
# Skip for ci test
# class TestToolChoiceGLM45(TestToolChoiceLlama32):