856 lines
34 KiB
Python
856 lines
34 KiB
Python
"""
|
|
Test script for tool_choice functionality in SGLang
|
|
Tests: required, auto, and specific function choices in both streaming and non-streaming modes
|
|
|
|
# To run the tests, use the following command:
|
|
#
|
|
# python3 -m unittest openai_server.function_call.test_tool_choice
|
|
"""
|
|
|
|
import json
|
|
import unittest
|
|
|
|
import openai
|
|
|
|
from sglang.srt.utils import kill_process_tree
|
|
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
|
|
from sglang.test.test_utils import (
|
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
DEFAULT_URL_FOR_TEST,
|
|
CustomTestCase,
|
|
popen_launch_server,
|
|
)
|
|
|
|
|
|
class TestToolChoiceLlama32(CustomTestCase):
|
|
|
|
@classmethod
def setUpClass(cls):
    """Launch a local OpenAI-compatible server for the Llama 3.2 test model.

    Sets the class attributes used by the tests: ``flaky_tests``, ``model``,
    ``base_url`` (with ``/v1`` appended once the server is up), ``api_key``,
    ``process`` and ``tokenizer``.
    """
    # Mark flaky tests for this model
    cls.flaky_tests = {
        "test_multi_tool_scenario_auto",
        "test_multi_tool_scenario_required",
    }

    # Use a model that supports function calling
    cls.model = "meta-llama/Llama-3.2-1B-Instruct"
    cls.base_url = DEFAULT_URL_FOR_TEST
    cls.api_key = "sk-123456"

    # Start the local OpenAI Server with tool calling support
    cls.process = popen_launch_server(
        cls.model,
        cls.base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        api_key=cls.api_key,
        other_args=[
            "--tool-call-parser",
            "llama3",  # Default parser for the test model
        ],
    )
    cls.base_url += "/v1"

    try:
        cls.tokenizer = get_tokenizer(cls.model)
    except BaseException:
        # unittest does not run tearDownClass when setUpClass raises, so the
        # server has to be stopped here or it would be left running.
        kill_process_tree(cls.process.pid)
        raise
|
|
|
|
@classmethod
def tearDownClass(cls):
    """Stop the server process that ``setUpClass`` stored on ``cls.process``."""
    server_pid = cls.process.pid
    kill_process_tree(server_pid)
|
|
|
|
def setUp(self):
    """Create a fresh OpenAI client and record the id of the first listed model."""
    client = openai.Client(api_key=self.api_key, base_url=self.base_url)
    self.client = client

    # The tests address the model by the id the server reports.
    listed_models = client.models.list().data
    self.model_name = listed_models[0].id
|
|
|
|
def _is_flaky_test(self):
|
|
"""Check if the current test is marked as flaky for this class"""
|
|
return (
|
|
hasattr(self.__class__, "flaky_tests")
|
|
and self._testMethodName in self.__class__.flaky_tests
|
|
)
|
|
|
|
def get_test_tools(self):
    """Build the three function-tool definitions shared by the basic tests.

    Returns a new list on every call, in this order: ``get_weather``,
    ``get_pokemon_info``, ``make_next_step_decision``.
    """
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "use this to get latest weather information for a city given its name",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "name of the city to get weather for",
                    },
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city"],
            },
        },
    }

    pokemon_tool = {
        "type": "function",
        "function": {
            "name": "get_pokemon_info",
            "description": "get detailed information about a pokemon given its name",
            "parameters": {
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string",
                        "description": "name of the pokemon to get info for",
                    }
                },
                "required": ["name"],
            },
        },
    }

    # This tool carries a long, prompt-like description and two required string fields.
    decision_tool = {
        "type": "function",
        "function": {
            "name": "make_next_step_decision",
            "description": "You will be given a trace of thinking process in the following format.\n\nQuestion: the input question you must answer\nTOOL: think about what to do, and choose a tool to use ONLY IF there are defined tools. \n You should never call the same tool with the same input twice in a row.\n If the previous conversation history already contains the information that can be retrieved from the tool, you should not call the tool again.\nOBSERVATION: the result of the tool call, NEVER include this in your response, this information will be provided\n... (this TOOL/OBSERVATION can repeat N times)\nANSWER: If you know the answer to the original question, require for more information,\n or you don't know the answer and there are no defined tools or all available tools are not helpful, respond with the answer without mentioning anything else.\n If the previous conversation history already contains the answer, respond with the answer right away.\n\n If no tools are configured, naturally mention this limitation while still being helpful. Briefly note that adding tools in the agent configuration would expand capabilities.\n\nYour task is to respond with the next step to take, based on the traces, \nor answer the question if you have enough information.",
            "parameters": {
                "type": "object",
                "properties": {
                    "decision": {
                        "type": "string",
                        "description": 'The next step to take, it must be either "TOOL" or "ANSWER". If the previous conversation history already contains the information that can be retrieved from the tool, you should not call the tool again. If there are no defined tools, you should not return "TOOL" in your response.',
                    },
                    "content": {
                        "type": "string",
                        "description": 'The content of the next step. If the decision is "TOOL", this should be a short and concise reasoning of why you chose the tool, MUST include the tool name. If the decision is "ANSWER", this should be the answer to the question. If no tools are available, integrate this limitation conversationally without sounding scripted.',
                    },
                },
                "required": ["decision", "content"],
            },
        },
    }

    return [weather_tool, pokemon_tool, decision_tool]
|
|
|
|
def get_test_messages(self):
    """Return a one-message conversation asking for the weather in five US cities.

    The single user message wraps the question in a TOOL/OBSERVATION/ANSWER
    trace prompt.
    """
    prompt = "Answer the following questions as best you can:\n\nYou will be given a trace of thinking process in the following format.\n\nQuestion: the input question you must answer\nTOOL: think about what to do, and choose a tool to use ONLY IF there are defined tools\nOBSERVATION: the result of the tool call or the observation of the current task, NEVER include this in your response, this information will be provided\n... (this TOOL/OBSERVATION can repeat N times)\nANSWER: If you know the answer to the original question, require for more information, \nif the previous conversation history already contains the answer, \nor you don't know the answer and there are no defined tools or all available tools are not helpful, respond with the answer without mentioning anything else.\nYou may use light Markdown formatting to improve clarity (e.g. lists, **bold**, *italics*), but keep it minimal and unobtrusive.\n\nYour task is to respond with the next step to take, based on the traces, \nor answer the question if you have enough information.\n\nQuestion: what is the weather in top 5 populated cities in the US in celsius?\n\nTraces:\n\n\nThese are some additional instructions that you should follow:"
    user_message = {"role": "user", "content": prompt}
    return [user_message]
|
|
|
|
def get_travel_tools(self):
    """Build the two tools of the travel-assistant scenario.

    Returns a new list holding ``get_weather`` followed by
    ``get_tourist_attractions``.
    """
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a given location.",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The name of the city or location.",
                    },
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["location"],
            },
        },
    }

    attractions_tool = {
        "type": "function",
        "function": {
            "name": "get_tourist_attractions",
            "description": "Get a list of top tourist attractions for a given city.",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The name of the city to find attractions for.",
                    }
                },
                "required": ["city"],
            },
        },
    }

    return [weather_tool, attractions_tool]
|
|
|
|
def get_travel_messages(self):
    """Return the travel-assistant conversation: a system prompt, then a Tokyo trip question."""
    system_message = {
        "content": "You are a travel assistant providing real-time weather updates and top tourist attractions.",
        "role": "system",
    }
    # The question asks about both weather and sights in one message.
    user_message = {
        "content": "I'm planning a trip to Tokyo next week. What's the weather like? What are the most amazing sights?",
        "role": "user",
    }
    return [system_message, user_message]
|
|
|
|
def test_tool_choice_auto_non_streaming(self):
    """Non-streaming request with tool_choice='auto' must return a message."""
    request = {
        "model": self.model_name,
        "tools": self.get_test_tools(),
        "messages": self.get_test_messages(),
        "max_tokens": 2048,
        "tool_choice": "auto",
        "stream": False,
    }
    completion = self.client.chat.completions.create(**request)

    # Tool calls are optional under "auto", so only the message itself is checked.
    first_choice = completion.choices[0]
    self.assertIsNotNone(first_choice.message)
|
|
|
|
def test_tool_choice_auto_streaming(self):
    """Test tool_choice='auto' in streaming mode.

    Consumes the whole stream and collects text and tool-call fragments.
    With 'auto' either kind of output is acceptable, so the test passes as
    long as the stream can be read to the end without raising.
    """
    tools = self.get_test_tools()
    messages = self.get_test_messages()

    response = self.client.chat.completions.create(
        model=self.model_name,
        messages=messages,
        max_tokens=2048,
        tools=tools,
        tool_choice="auto",
        stream=True,
    )

    # Collect streaming response
    content_chunks = []
    tool_call_chunks = []

    for chunk in response:
        delta = chunk.choices[0].delta
        # A delta may carry text and tool-call fragments at the same time;
        # check both independently so neither is dropped.
        if delta.content:
            content_chunks.append(delta.content)
        if delta.tool_calls:
            tool_call_chunks.extend(delta.tool_calls)

    # Should complete without errors
    self.assertIsInstance(content_chunks, list)
    self.assertIsInstance(tool_call_chunks, list)
|
|
|
|
def test_tool_choice_required_non_streaming(self):
    """Non-streaming request with tool_choice='required' must yield at least one tool call."""
    request = {
        "model": self.model_name,
        "tools": self.get_test_tools(),
        "messages": self.get_test_messages(),
        "max_tokens": 2048,
        "temperature": 0.2,
        "tool_choice": "required",
        "stream": False,
    }
    completion = self.client.chat.completions.create(**request)

    # "required" means the reply has to contain tool calls.
    returned_calls = completion.choices[0].message.tool_calls
    self.assertIsNotNone(returned_calls)
    self.assertGreater(len(returned_calls), 0)
|
|
|
|
def test_tool_choice_required_streaming(self):
    """Streaming request with tool_choice='required' must emit tool-call deltas."""
    request = {
        "model": self.model_name,
        "tools": self.get_test_tools(),
        "messages": self.get_test_messages(),
        "max_tokens": 2048,
        "tool_choice": "required",
        "stream": True,
    }
    stream = self.client.chat.completions.create(**request)

    # Flatten every tool-call fragment seen across the stream.
    tool_call_pieces = [
        piece
        for chunk in stream
        for piece in (chunk.choices[0].delta.tool_calls or ())
    ]

    # "required" means at least one fragment must have arrived.
    self.assertGreater(len(tool_call_pieces), 0)
|
|
|
|
def test_tool_choice_specific_function_non_streaming(self):
    """Forcing ``get_weather`` via tool_choice must produce only ``get_weather`` calls (non-streaming)."""
    forced_choice = {"type": "function", "function": {"name": "get_weather"}}
    request = {
        "model": self.model_name,
        "tools": self.get_test_tools(),
        "messages": self.get_test_messages(),
        "max_tokens": 2048,
        "tool_choice": forced_choice,
        "stream": False,
    }
    completion = self.client.chat.completions.create(**request)

    returned_calls = completion.choices[0].message.tool_calls
    self.assertIsNotNone(returned_calls)
    # The prompt asks about the top 5 populated US cities, so up to 5 calls may come back.
    self.assertGreaterEqual(len(returned_calls), 1)
    for returned_call in returned_calls:
        self.assertEqual(returned_call.function.name, "get_weather")
|
|
|
|
def test_tool_choice_specific_function_streaming(self):
    """Forcing ``get_weather`` via tool_choice must stream a call named ``get_weather``."""
    forced_choice = {"type": "function", "function": {"name": "get_weather"}}
    request = {
        "model": self.model_name,
        "tools": self.get_test_tools(),
        "messages": self.get_test_messages(),
        "max_tokens": 2048,
        "tool_choice": forced_choice,
        "stream": True,
    }
    stream = self.client.chat.completions.create(**request)

    # Flatten every tool-call fragment seen across the stream.
    tool_call_pieces = [
        piece
        for chunk in stream
        for piece in (chunk.choices[0].delta.tool_calls or ())
    ]
    self.assertGreater(len(tool_call_pieces), 0)

    # The first fragment that carries a function name decides the result.
    first_name = next(
        (
            piece.function.name
            for piece in tool_call_pieces
            if piece.function and piece.function.name
        ),
        None,
    )
    self.assertEqual(first_name, "get_weather")
|
|
|
|
def test_required_streaming_arguments_chunks_json(self):
    """In streaming required mode, complete tool call arguments should be valid JSON when all chunks are combined.

    Tool calls are reassembled from their streamed deltas (grouped by the
    delta ``index``); each reassembled call must have a non-empty function
    name and arguments that parse as a JSON object.
    """
    tools = self.get_test_tools()
    messages = self.get_test_messages()

    response = self.client.chat.completions.create(
        model=self.model_name,
        messages=messages,
        max_tokens=1024,
        temperature=0.1,
        tools=tools,
        tool_choice="required",
        stream=True,
    )

    # Collect all tool call chunks and reconstruct complete tool calls
    tool_calls_by_index = {}
    for chunk in response:
        for tool_call_delta in chunk.choices[0].delta.tool_calls or ():
            tool_index = tool_call_delta.index

            # Initialize tool call if not seen before
            if tool_index not in tool_calls_by_index:
                tool_calls_by_index[tool_index] = {
                    "id": tool_call_delta.id,
                    "type": "function",
                    "function": {"name": "", "arguments": ""},
                }

            function_delta = tool_call_delta.function
            if not function_delta:
                continue
            assembled = tool_calls_by_index[tool_index]["function"]

            # Update function name if present (first chunk)
            if function_delta.name:
                assembled["name"] = function_delta.name

            # Accumulate arguments (all chunks)
            if function_delta.arguments:
                assembled["arguments"] += function_delta.arguments

    self.assertGreater(len(tool_calls_by_index), 0)

    # Validate that complete tool calls have valid JSON arguments
    for tool_index, tool_call in tool_calls_by_index.items():
        name = tool_call["function"]["name"]
        arguments = tool_call["function"]["arguments"]

        # Both fields start out as "", so a not-None check can never fail;
        # require an actual name instead. Empty arguments are rejected by
        # the JSON parse below.
        self.assertTrue(name, f"Tool call {tool_index} has no function name")

        # The complete arguments should be valid JSON
        try:
            args = json.loads(arguments)
        except json.JSONDecodeError:
            self.fail(f"Invalid JSON in complete tool call arguments: {arguments}")
        self.assertIsInstance(args, dict)
|
|
|
|
def test_complex_parameters_required_non_streaming(self):
    """Required mode with a nested schema must return ``analyze_data`` calls whose JSON has a ``data`` object."""
    # Nested object holding an array of strings plus a small config object.
    data_schema = {
        "type": "object",
        "properties": {
            "metrics": {
                "type": "array",
                "items": {"type": "string"},
            },
            "config": {
                "type": "object",
                "properties": {
                    "threshold": {"type": "number"},
                    "enabled": {"type": "boolean"},
                },
            },
        },
        "required": ["metrics"],
    }
    # Array of name/value objects.
    options_schema = {
        "type": "array",
        "items": {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "value": {"type": "string"},
            },
        },
    }
    complex_tools = [
        {
            "type": "function",
            "function": {
                "name": "analyze_data",
                "description": "Analyze complex data structures",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "data": data_schema,
                        "options": options_schema,
                    },
                    "required": ["data"],
                },
            },
        }
    ]
    conversation = [
        {
            "role": "user",
            "content": "Analyze some data with metrics and configuration",
        }
    ]

    completion = self.client.chat.completions.create(
        model=self.model_name,
        messages=conversation,
        max_tokens=1024,
        temperature=0.1,
        tools=complex_tools,
        tool_choice="required",
        stream=False,
    )

    returned_calls = completion.choices[0].message.tool_calls
    self.assertIsNotNone(returned_calls)
    self.assertGreater(len(returned_calls), 0)

    for returned_call in returned_calls:
        self.assertEqual(returned_call.function.name, "analyze_data")
        raw_arguments = returned_call.function.arguments
        try:
            parsed = json.loads(raw_arguments)
        except json.JSONDecodeError:
            self.fail(
                f"Invalid JSON in complex tool call arguments: {raw_arguments}"
            )
        self.assertIsInstance(parsed, dict)
        self.assertIn("data", parsed)
        self.assertIsInstance(parsed["data"], dict)
|
|
|
|
def test_multi_tool_scenario_auto(self):
    """Travel scenario with tool_choice='auto'.

    For classes that mark this test flaky, any returned calls only need to
    name an offered tool. Otherwise exactly two calls are expected, one per
    travel tool.
    """
    travel_tools = self.get_travel_tools()
    conversation = self.get_travel_messages()

    completion = self.client.chat.completions.create(
        model=self.model_name,
        messages=conversation,
        max_tokens=2048,
        temperature=0.2,
        tools=travel_tools,
        tool_choice="auto",
        stream=False,
    )

    # The request itself has to succeed in every case.
    reply = completion.choices[0].message
    self.assertIsNotNone(reply)

    returned_calls = reply.tool_calls
    expected_functions = {"get_weather", "get_tourist_attractions"}

    if self._is_flaky_test():
        # Lenient path: tool calls are optional, but must name an offered tool.
        if returned_calls:
            offered_names = [tool["function"]["name"] for tool in travel_tools]
            for returned_call in returned_calls:
                self.assertIn(returned_call.function.name, offered_names)
        return

    # Strict path: both tools must be called exactly once each.
    self.assertIsNotNone(returned_calls, "Expected tool calls but got none")
    self.assertEqual(
        len(returned_calls), 2, f"Expected 2 tool calls, got {len(returned_calls)}"
    )

    called_functions = {returned_call.function.name for returned_call in returned_calls}
    self.assertEqual(
        called_functions,
        expected_functions,
        f"Expected functions {expected_functions}, got {called_functions}",
    )
|
|
|
|
def test_multi_tool_scenario_required(self):
    """Travel scenario with tool_choice='required'.

    At least one tool call with a name and arguments is always expected.
    Classes that mark this test flaky only check that each call names an
    offered tool; otherwise exactly two calls, one per travel tool, are
    required.
    """
    travel_tools = self.get_travel_tools()
    conversation = self.get_travel_messages()

    completion = self.client.chat.completions.create(
        model=self.model_name,
        messages=conversation,
        max_tokens=2048,
        temperature=0.2,
        tools=travel_tools,
        tool_choice="required",
        stream=False,
    )

    # "required" guarantees at least one tool call.
    returned_calls = completion.choices[0].message.tool_calls
    self.assertIsNotNone(returned_calls)
    self.assertGreater(len(returned_calls), 0)

    offered_names = [tool["function"]["name"] for tool in travel_tools]
    expected_functions = {"get_weather", "get_tourist_attractions"}

    for returned_call in returned_calls:
        self.assertIsNotNone(returned_call.function.name)
        self.assertIsNotNone(returned_call.function.arguments)

    if self._is_flaky_test():
        # Lenient path: just make sure the basics work.
        self.assertGreater(
            len(returned_calls),
            0,
            f"Expected at least 1 tool call, got {len(returned_calls)}",
        )
        for returned_call in returned_calls:
            self.assertIn(returned_call.function.name, offered_names)
        return

    # Strict path: both tools must be called exactly once each.
    self.assertEqual(
        len(returned_calls), 2, f"Expected 2 tool calls, got {len(returned_calls)}"
    )

    called_functions = {returned_call.function.name for returned_call in returned_calls}
    self.assertEqual(
        called_functions,
        expected_functions,
        f"Expected functions {expected_functions}, got {called_functions}",
    )
|
|
|
|
def test_error_handling_invalid_tool_choice(self):
    """A tool_choice naming a function absent from ``tools`` must raise BadRequestError."""
    # This name does not appear in the tool list returned by get_test_tools().
    unknown_choice = {"type": "function", "function": {"name": "nonexistent_function"}}
    request = {
        "model": self.model_name,
        "tools": self.get_test_tools(),
        "messages": self.get_test_messages(),
        "max_tokens": 2048,
        "tool_choice": unknown_choice,
        "stream": False,
    }

    with self.assertRaises(openai.BadRequestError) as raised:
        self.client.chat.completions.create(**request)

    # The error text has to identify the missing tool.
    error_text = str(raised.exception)
    self.assertIn(
        "Tool 'nonexistent_function' not found in tools list",
        error_text,
    )
|
|
|
|
def test_invalid_tool_missing_name(self):
    """A tool definition without a ``name`` must be rejected with BadRequestError."""
    parameters_schema = {
        "type": "object",
        "properties": {
            "test_field": {
                "type": "string",
                "description": "Test field",
            }
        },
        "required": ["test_field"],
    }
    # The "function" entry deliberately has no "name" key.
    nameless_tools = [
        {
            "type": "function",
            "function": {
                "description": "Test function with invalid schema",
                "parameters": parameters_schema,
            },
        }
    ]
    conversation = [
        {
            "role": "user",
            "content": "Test the function",
        }
    ]

    with self.assertRaises(openai.BadRequestError) as raised:
        self.client.chat.completions.create(
            model=self.model_name,
            messages=conversation,
            max_tokens=100,
            temperature=0.1,
            tools=nameless_tools,
            tool_choice="required",
            stream=False,
        )

    # The (lower-cased) error text has to mention the name field.
    lowered_error = str(raised.exception).lower()
    self.assertIn("name", lowered_error)
|
|
|
|
def test_invalid_json_schema_in_tool(self):
    """A tool whose parameter schema uses an unknown type must be rejected with BadRequestError."""
    broken_schema = {
        "type": "object",
        "properties": {
            "invalid_field": {
                "type": "unknown_type",  # Invalid type
                "description": "This field has an invalid type",
            }
        },
        "required": ["invalid_field"],
    }
    broken_tools = [
        {
            "type": "function",
            "function": {
                "name": "test_function",
                "description": "Test function with invalid JSON schema",
                "parameters": broken_schema,
            },
        }
    ]
    conversation = [
        {
            "role": "user",
            "content": "Test the function",
        }
    ]

    with self.assertRaises(openai.BadRequestError) as raised:
        self.client.chat.completions.create(
            model=self.model_name,
            messages=conversation,
            max_tokens=100,
            temperature=0.1,
            tools=broken_tools,
            tool_choice="required",
            stream=False,
        )

    # The (lower-cased) error text has to point at the parameters schema.
    lowered_error = str(raised.exception).lower()
    self.assertIn("invalid 'parameters' schema", lowered_error)
|
|
|
|
def test_conflicting_defs_required_tool_choice(self):
    """Two tools defining ``$defs/DataType`` differently must be rejected in required mode."""

    def build_tool(name, description, value_type):
        # Each tool references "#/$defs/DataType"; only the type of "value" differs.
        return {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": {
                    "type": "object",
                    "properties": {
                        "data": {"$ref": "#/$defs/DataType"},
                    },
                    "required": ["data"],
                    "$defs": {
                        "DataType": {
                            "type": "object",
                            "properties": {"value": {"type": value_type}},
                            "required": ["value"],
                        },
                    },
                },
            },
        }

    conflicting_tools = [
        build_tool("tool1", "Tool 1 with conflicting $defs", "string"),
        # Different definition for DataType
        build_tool("tool2", "Tool 2 with conflicting $defs", "number"),
    ]
    conversation = [
        {
            "role": "user",
            "content": "Test the conflicting tools",
        }
    ]

    with self.assertRaises(openai.BadRequestError) as raised:
        self.client.chat.completions.create(
            model=self.model_name,
            messages=conversation,
            max_tokens=100,
            temperature=0.1,
            tools=conflicting_tools,
            tool_choice="required",
            stream=False,
        )

    # The (lower-cased) error text has to describe the conflict.
    lowered_error = str(raised.exception).lower()
    self.assertIn("multiple schemas", lowered_error)
    self.assertIn("not supported", lowered_error)
|
|
|
|
|
|
class TestToolChoiceQwen25(TestToolChoiceLlama32):
    """Test tool_choice functionality with Qwen2.5 model"""

    @classmethod
    def setUpClass(cls):
        """Launch a local server for Qwen2.5 with the ``qwen25`` tool-call parser.

        No test is marked flaky for this model, so the inherited multi-tool
        tests run their strict branch.
        """
        # An empty set, matching the set used by the base class ({} would be a dict).
        cls.flaky_tests = set()

        cls.model = "Qwen/Qwen2.5-7B-Instruct"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"

        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=[
                "--tool-call-parser",
                "qwen25",
            ],
        )
        cls.base_url += "/v1"

        try:
            cls.tokenizer = get_tokenizer(cls.model)
        except BaseException:
            # unittest does not run tearDownClass when setUpClass raises, so
            # the server has to be stopped here or it would be left running.
            kill_process_tree(cls.process.pid)
            raise
|
|
|
|
|
|
class TestToolChoiceMistral(TestToolChoiceLlama32):
    """Test tool_choice functionality with Mistral model"""

    @classmethod
    def setUpClass(cls):
        """Launch a local server for Mistral 7B with the ``mistral`` tool-call parser."""
        # Mark flaky tests for this model
        cls.flaky_tests = {
            "test_multi_tool_scenario_auto",
            "test_multi_tool_scenario_required",
        }

        cls.model = "mistralai/Mistral-7B-Instruct-v0.3"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"

        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=[
                "--tool-call-parser",
                "mistral",
            ],
        )
        cls.base_url += "/v1"

        try:
            cls.tokenizer = get_tokenizer(cls.model)
        except BaseException:
            # unittest does not run tearDownClass when setUpClass raises, so
            # the server has to be stopped here or it would be left running.
            kill_process_tree(cls.process.pid)
            raise

    @unittest.skip("Fails due to whitespace issue with Mistral - skipping")
    def test_multi_tool_scenario_required(self):
        """Test multi-tool scenario with tool_choice='required'"""
        super().test_multi_tool_scenario_required()

    @unittest.skip("Fails due to whitespace issue with Mistral - skipping")
    def test_complex_parameters_required_non_streaming(self):
        """Validate complex nested parameter schemas in non-streaming required mode"""
        super().test_complex_parameters_required_non_streaming()
|
|
|
|
|
|
# Skip for ci test
|
|
# class TestToolChoiceGLM45(TestToolChoiceLlama32):
|
|
# @classmethod
|
|
# def setUpClass(cls):
|
|
# # Replace with the model name needed for testing; if not required, reuse DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
|
# cls.model = "THUDM/GLM-4.5"
|
|
# cls.base_url = DEFAULT_URL_FOR_TEST
|
|
# cls.api_key = "sk-123456"
|
|
|
|
# # Start the local OpenAI Server. If necessary, you can add other parameters such as --enable-tools.
|
|
# cls.process = popen_launch_server(
|
|
# cls.model,
|
|
# cls.base_url,
|
|
# timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
# api_key=cls.api_key,
|
|
# other_args=[
|
|
# # If your server needs extra parameters to test function calling, please add them here.
|
|
# "--tool-call-parser",
|
|
# "glm45",
|
|
# "--reasoning-parser",
|
|
# "glm45",
|
|
# "--tp-size",
|
|
# "8"
|
|
# ],
|
|
# )
|
|
# cls.base_url += "/v1"
|
|
# cls.tokenizer = get_tokenizer(cls.model)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the tests in this module when the file is executed directly.
    unittest.main()
|