Use jsonschema to constrain required or specific tool choice (#10550)
This commit is contained in:
@@ -16,7 +16,7 @@
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Optional, TypeAlias, Union
|
from typing import Any, Dict, List, NamedTuple, Optional, TypeAlias, Union
|
||||||
|
|
||||||
from openai.types.responses import (
|
from openai.types.responses import (
|
||||||
ResponseFunctionToolCall,
|
ResponseFunctionToolCall,
|
||||||
@@ -392,7 +392,7 @@ class Function(BaseModel):
|
|||||||
"""Function descriptions."""
|
"""Function descriptions."""
|
||||||
|
|
||||||
description: Optional[str] = Field(default=None, examples=[None])
|
description: Optional[str] = Field(default=None, examples=[None])
|
||||||
name: Optional[str] = None
|
name: str
|
||||||
parameters: Optional[object] = None
|
parameters: Optional[object] = None
|
||||||
strict: bool = False
|
strict: bool = False
|
||||||
|
|
||||||
@@ -943,6 +943,16 @@ class MessageProcessingResult:
|
|||||||
tool_call_constraint: Optional[Any] = None
|
tool_call_constraint: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallProcessingResult(NamedTuple):
|
||||||
|
"""Result of processing tool calls in a response."""
|
||||||
|
|
||||||
|
tool_calls: Optional[
|
||||||
|
List[Any]
|
||||||
|
] # List of ToolCall objects or None if parsing failed
|
||||||
|
remaining_text: str # Text remaining after parsing tool calls
|
||||||
|
finish_reason: Dict[str, Any] # Updated finish reason dictionary
|
||||||
|
|
||||||
|
|
||||||
class ResponseReasoningTextContent(BaseModel):
|
class ResponseReasoningTextContent(BaseModel):
|
||||||
text: str
|
text: str
|
||||||
type: Literal["reasoning_text"] = "reasoning_text"
|
type: Literal["reasoning_text"] = "reasoning_text"
|
||||||
|
|||||||
@@ -62,6 +62,12 @@ class OpenAIServingBase(ABC):
|
|||||||
return self.create_error_response(
|
return self.create_error_response(
|
||||||
message=e.detail, err_type=str(e.status_code), status_code=e.status_code
|
message=e.detail, err_type=str(e.status_code), status_code=e.status_code
|
||||||
)
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
return self.create_error_response(
|
||||||
|
message=str(e),
|
||||||
|
err_type="BadRequest",
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"Error in request: {e}")
|
logger.exception(f"Error in request: {e}")
|
||||||
return self.create_error_response(
|
return self.create_error_response(
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Uni
|
|||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import ORJSONResponse, StreamingResponse
|
from fastapi.responses import ORJSONResponse, StreamingResponse
|
||||||
|
from jsonschema import Draft202012Validator, SchemaError
|
||||||
|
|
||||||
from sglang.srt.entrypoints.openai.protocol import (
|
from sglang.srt.entrypoints.openai.protocol import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
@@ -25,6 +26,8 @@ from sglang.srt.entrypoints.openai.protocol import (
|
|||||||
LogProbs,
|
LogProbs,
|
||||||
MessageProcessingResult,
|
MessageProcessingResult,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
|
ToolCallProcessingResult,
|
||||||
|
ToolChoice,
|
||||||
TopLogprob,
|
TopLogprob,
|
||||||
)
|
)
|
||||||
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
|
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
|
||||||
@@ -35,6 +38,8 @@ from sglang.srt.entrypoints.openai.utils import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.function_call.core_types import ToolCallItem
|
from sglang.srt.function_call.core_types import ToolCallItem
|
||||||
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||||
|
from sglang.srt.function_call.json_array_parser import JsonArrayParser
|
||||||
|
from sglang.srt.function_call.utils import get_json_schema_constraint
|
||||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||||
from sglang.srt.parser.conversation import generate_chat_conv
|
from sglang.srt.parser.conversation import generate_chat_conv
|
||||||
from sglang.srt.parser.jinja_template_utils import process_content_for_template_format
|
from sglang.srt.parser.jinja_template_utils import process_content_for_template_format
|
||||||
@@ -75,6 +80,23 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
):
|
):
|
||||||
return "Tools cannot be empty if tool choice is set to required."
|
return "Tools cannot be empty if tool choice is set to required."
|
||||||
|
|
||||||
|
if request.tool_choice is not None and not isinstance(request.tool_choice, str):
|
||||||
|
if not request.tools:
|
||||||
|
return "Tools cannot be empty if tool choice is set to a specific tool."
|
||||||
|
tool_name = request.tool_choice.function.name
|
||||||
|
tool_exists = any(tool.function.name == tool_name for tool in request.tools)
|
||||||
|
if not tool_exists:
|
||||||
|
return f"Tool '{tool_name}' not found in tools list."
|
||||||
|
|
||||||
|
# Validate tool definitions
|
||||||
|
for i, tool in enumerate(request.tools or []):
|
||||||
|
if tool.function.parameters is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
Draft202012Validator.check_schema(tool.function.parameters)
|
||||||
|
except SchemaError as e:
|
||||||
|
return f"Tool {i} function has invalid 'parameters' schema: {str(e)}"
|
||||||
|
|
||||||
max_output_tokens = request.max_completion_tokens or request.max_tokens
|
max_output_tokens = request.max_completion_tokens or request.max_tokens
|
||||||
server_context_length = self.tokenizer_manager.server_args.context_length
|
server_context_length = self.tokenizer_manager.server_args.context_length
|
||||||
if (
|
if (
|
||||||
@@ -190,6 +212,14 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
tool_call_constraint = parser.get_structure_constraint(
|
tool_call_constraint = parser.get_structure_constraint(
|
||||||
request.tool_choice
|
request.tool_choice
|
||||||
)
|
)
|
||||||
|
# Handle JSON schema constraint directly for required or named tool choice
|
||||||
|
if request.tool_choice == "required" or isinstance(
|
||||||
|
request.tool_choice, ToolChoice
|
||||||
|
):
|
||||||
|
json_schema = get_json_schema_constraint(
|
||||||
|
request.tools, request.tool_choice
|
||||||
|
)
|
||||||
|
tool_call_constraint = ("json_schema", json_schema)
|
||||||
|
|
||||||
# Use chat template
|
# Use chat template
|
||||||
if self.template_manager.chat_template_name is None:
|
if self.template_manager.chat_template_name is None:
|
||||||
@@ -437,6 +467,10 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
sampling_params[constraint_type] = convert_json_schema_to_str(
|
sampling_params[constraint_type] = convert_json_schema_to_str(
|
||||||
constraint_value.model_dump(by_alias=True)
|
constraint_value.model_dump(by_alias=True)
|
||||||
)
|
)
|
||||||
|
elif constraint_type == "json_schema":
|
||||||
|
sampling_params[constraint_type] = convert_json_schema_to_str(
|
||||||
|
constraint_value
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
sampling_params[constraint_type] = constraint_value
|
sampling_params[constraint_type] = constraint_value
|
||||||
return sampling_params
|
return sampling_params
|
||||||
@@ -752,7 +786,11 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
):
|
):
|
||||||
history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
|
history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
|
||||||
tool_calls, text, finish_reason = self._process_tool_calls(
|
tool_calls, text, finish_reason = self._process_tool_calls(
|
||||||
text, request.tools, finish_reason, history_tool_calls_cnt
|
text,
|
||||||
|
request.tools,
|
||||||
|
finish_reason,
|
||||||
|
request.tool_choice,
|
||||||
|
history_tool_calls_cnt,
|
||||||
)
|
)
|
||||||
|
|
||||||
choice_data = ChatCompletionResponseChoice(
|
choice_data = ChatCompletionResponseChoice(
|
||||||
@@ -867,9 +905,51 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
text: str,
|
text: str,
|
||||||
tools: List[Any],
|
tools: List[Any],
|
||||||
finish_reason: Dict[str, Any],
|
finish_reason: Dict[str, Any],
|
||||||
|
tool_choice: Optional[Union[str, ToolChoice]] = None,
|
||||||
history_tool_calls_cnt: int = 0,
|
history_tool_calls_cnt: int = 0,
|
||||||
) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
|
) -> ToolCallProcessingResult:
|
||||||
"""Process tool calls in the response"""
|
"""Process tool calls in the response"""
|
||||||
|
|
||||||
|
# Handle required or named tool choice
|
||||||
|
if tool_choice == "required" or (
|
||||||
|
isinstance(tool_choice, ToolChoice) and tool_choice.type == "function"
|
||||||
|
):
|
||||||
|
# Set finish reason to tool_calls since we're processing tool calls
|
||||||
|
if finish_reason["type"] == "stop":
|
||||||
|
finish_reason["type"] = "tool_calls"
|
||||||
|
finish_reason["matched"] = None
|
||||||
|
try:
|
||||||
|
# For required tool choice, we expect a JSON array of tool calls
|
||||||
|
tool_call_data = json.loads(text)
|
||||||
|
tool_calls = []
|
||||||
|
for i, tool in enumerate(tool_call_data):
|
||||||
|
# Create a ToolCallItem from the JSON data
|
||||||
|
call_info = ToolCallItem(
|
||||||
|
tool_index=i, # Use the loop index as tool_index
|
||||||
|
name=tool["name"],
|
||||||
|
parameters=json.dumps(tool["parameters"], ensure_ascii=False),
|
||||||
|
)
|
||||||
|
tool_id = self._process_tool_call_id(
|
||||||
|
call_info, history_tool_calls_cnt
|
||||||
|
)
|
||||||
|
tool_calls.append(
|
||||||
|
ToolCall(
|
||||||
|
id=tool_id,
|
||||||
|
index=i,
|
||||||
|
function=FunctionResponse(
|
||||||
|
name=tool["name"],
|
||||||
|
arguments=json.dumps(
|
||||||
|
tool["parameters"], ensure_ascii=False
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return ToolCallProcessingResult(tool_calls, "", finish_reason)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"Tool call parsing error: {e}")
|
||||||
|
return ToolCallProcessingResult(None, text, finish_reason)
|
||||||
|
|
||||||
|
# Use parser since output is not constrained by JSON schema
|
||||||
parser = FunctionCallParser(tools, self.tool_call_parser)
|
parser = FunctionCallParser(tools, self.tool_call_parser)
|
||||||
if parser.has_tool_call(text):
|
if parser.has_tool_call(text):
|
||||||
if finish_reason["type"] == "stop":
|
if finish_reason["type"] == "stop":
|
||||||
@@ -891,13 +971,13 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return tool_calls, text, finish_reason
|
return ToolCallProcessingResult(tool_calls, text, finish_reason)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Tool call parsing error: {e}")
|
logger.error(f"Tool call parsing error: {e}")
|
||||||
# Return error but don't fail the whole request
|
# Return error but don't fail the whole request
|
||||||
return None, text, finish_reason
|
return ToolCallProcessingResult(None, text, finish_reason)
|
||||||
|
|
||||||
return None, text, finish_reason
|
return ToolCallProcessingResult(None, text, finish_reason)
|
||||||
|
|
||||||
def _process_streaming_logprobs(
|
def _process_streaming_logprobs(
|
||||||
self, content: Dict[str, Any], n_prev_token: int
|
self, content: Dict[str, Any], n_prev_token: int
|
||||||
@@ -990,13 +1070,25 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
):
|
):
|
||||||
"""Process tool calls in streaming response"""
|
"""Process tool calls in streaming response"""
|
||||||
if index not in parser_dict:
|
if index not in parser_dict:
|
||||||
parser_dict[index] = FunctionCallParser(
|
# Use JSON detector directly for required or named tool choice
|
||||||
tools=request.tools,
|
if request.tool_choice == "required" or isinstance(
|
||||||
tool_call_parser=self.tool_call_parser,
|
request.tool_choice, ToolChoice
|
||||||
)
|
):
|
||||||
|
parser_dict[index] = JsonArrayParser()
|
||||||
|
else:
|
||||||
|
parser_dict[index] = FunctionCallParser(
|
||||||
|
tools=request.tools,
|
||||||
|
tool_call_parser=self.tool_call_parser,
|
||||||
|
)
|
||||||
|
|
||||||
parser = parser_dict[index]
|
parser = parser_dict[index]
|
||||||
|
|
||||||
normal_text, calls = parser.parse_stream_chunk(delta)
|
# Handle both FunctionCallParser and JsonArrayParser
|
||||||
|
if isinstance(parser, JsonArrayParser):
|
||||||
|
result = parser.parse_streaming_increment(delta, request.tools)
|
||||||
|
normal_text, calls = result.normal_text, result.calls
|
||||||
|
else:
|
||||||
|
normal_text, calls = parser.parse_stream_chunk(delta)
|
||||||
|
|
||||||
# Yield normal text
|
# Yield normal text
|
||||||
if normal_text:
|
if normal_text:
|
||||||
@@ -1055,7 +1147,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
|
|
||||||
def _check_for_unstreamed_tool_args(
|
def _check_for_unstreamed_tool_args(
|
||||||
self,
|
self,
|
||||||
parser: FunctionCallParser,
|
parser: Union[FunctionCallParser, JsonArrayParser],
|
||||||
content: Dict[str, Any],
|
content: Dict[str, Any],
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
index: int,
|
index: int,
|
||||||
@@ -1065,30 +1157,31 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
when generation finishes. This ensures tool calls are properly completed
|
when generation finishes. This ensures tool calls are properly completed
|
||||||
even if the model generates the final arguments in the last chunk.
|
even if the model generates the final arguments in the last chunk.
|
||||||
"""
|
"""
|
||||||
# Only check if we have tool calls and the parser has tracked data
|
# Get the detector - either from FunctionCallParser or directly if json detector
|
||||||
|
detector = parser.detector if hasattr(parser, "detector") else parser
|
||||||
|
|
||||||
|
# Only check if we have tool calls and the detector has tracked data
|
||||||
if (
|
if (
|
||||||
not hasattr(parser.detector, "prev_tool_call_arr")
|
not hasattr(detector, "prev_tool_call_arr")
|
||||||
or not parser.detector.prev_tool_call_arr
|
or not detector.prev_tool_call_arr
|
||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not hasattr(parser.detector, "streamed_args_for_tool")
|
not hasattr(detector, "streamed_args_for_tool")
|
||||||
or not parser.detector.streamed_args_for_tool
|
or not detector.streamed_args_for_tool
|
||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Get the last tool call that was being processed
|
# Get the last tool call that was being processed
|
||||||
tool_index = len(parser.detector.prev_tool_call_arr) - 1
|
tool_index = len(detector.prev_tool_call_arr) - 1
|
||||||
if tool_index < 0 or tool_index >= len(parser.detector.streamed_args_for_tool):
|
if tool_index < 0 or tool_index >= len(detector.streamed_args_for_tool):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Get expected vs actual arguments
|
# Get expected vs actual arguments
|
||||||
expected_args = parser.detector.prev_tool_call_arr[tool_index].get(
|
expected_args = detector.prev_tool_call_arr[tool_index].get("arguments", {})
|
||||||
"arguments", {}
|
|
||||||
)
|
|
||||||
expected_call = json.dumps(expected_args, ensure_ascii=False)
|
expected_call = json.dumps(expected_args, ensure_ascii=False)
|
||||||
actual_call = parser.detector.streamed_args_for_tool[tool_index]
|
actual_call = detector.streamed_args_for_tool[tool_index]
|
||||||
|
|
||||||
# Check if there are remaining arguments to send
|
# Check if there are remaining arguments to send
|
||||||
remaining_call = (
|
remaining_call = (
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
|||||||
from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
|
from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
|
||||||
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
||||||
from sglang.srt.function_call.step3_detector import Step3Detector
|
from sglang.srt.function_call.step3_detector import Step3Detector
|
||||||
|
from sglang.srt.function_call.utils import get_json_schema_constraint
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -178,8 +179,8 @@ class FunctionCallParser:
|
|||||||
strict_tag = self.get_structure_tag()
|
strict_tag = self.get_structure_tag()
|
||||||
return ("structural_tag", strict_tag)
|
return ("structural_tag", strict_tag)
|
||||||
elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
|
elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
|
||||||
ebnf = self.get_ebnf(tool_choice)
|
json_schema = get_json_schema_constraint(self.tools, tool_choice)
|
||||||
return ("ebnf", ebnf) if ebnf is not None else None
|
return ("json_schema", json_schema)
|
||||||
|
|
||||||
def get_ebnf(
|
def get_ebnf(
|
||||||
self, tool_choice: Union[ToolChoice, Literal["required"]]
|
self, tool_choice: Union[ToolChoice, Literal["required"]]
|
||||||
|
|||||||
63
python/sglang/srt/function_call/json_array_parser.py
Normal file
63
python/sglang/srt/function_call/json_array_parser.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||||
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||||
|
from sglang.srt.function_call.core_types import StreamingParseResult
|
||||||
|
|
||||||
|
|
||||||
|
class JsonArrayParser(BaseFormatDetector):
|
||||||
|
"""
|
||||||
|
Parser for JSON array tool calls when JSON schema constraints are active.
|
||||||
|
|
||||||
|
This parser is used when tool_choice="required" or a specific tool is named,
|
||||||
|
bypassing model-specific parsers in favor of direct JSON array parsing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
# Configure for JSON array parsing
|
||||||
|
self.bot_token = "["
|
||||||
|
self.eot_token = "]"
|
||||||
|
self.tool_call_separator = ","
|
||||||
|
|
||||||
|
def has_tool_call(self, text: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the given text contains a JSON tool call (array or single object).
|
||||||
|
"""
|
||||||
|
return "[" in text or "{" in text
|
||||||
|
|
||||||
|
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||||
|
"""
|
||||||
|
Parse JSON tool calls using the base class implementation.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Detect and parse not supported for JSON schema constraints."
|
||||||
|
)
|
||||||
|
|
||||||
|
def build_ebnf(self, tools: List[Tool]) -> str:
|
||||||
|
"""
|
||||||
|
Build an EBNF grammar for constrained generation.
|
||||||
|
This is not used for JSON schema constraints as they are handled
|
||||||
|
by the constraint backends directly.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"EBNF generation is not supported for JSON schema constraints."
|
||||||
|
)
|
||||||
|
|
||||||
|
def parse_streaming_increment(
|
||||||
|
self, new_text: str, tools: List[Tool]
|
||||||
|
) -> StreamingParseResult:
|
||||||
|
"""
|
||||||
|
Streaming incremental parsing with tool validation.
|
||||||
|
"""
|
||||||
|
return super().parse_streaming_increment(new_text, tools)
|
||||||
|
|
||||||
|
def structure_info(self) -> callable:
|
||||||
|
"""
|
||||||
|
Return a function that creates StructureInfo for constrained generation.
|
||||||
|
This is not used for JSON schema constraints as they are handled
|
||||||
|
by the constraint backends directly.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("structure_info not used for JSON schema constraints")
|
||||||
@@ -1,10 +1,13 @@
|
|||||||
import json
|
import json
|
||||||
from json import JSONDecodeError, JSONDecoder
|
from json import JSONDecodeError, JSONDecoder
|
||||||
from typing import Any, Tuple
|
from json.decoder import WHITESPACE
|
||||||
|
from typing import Any, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
import partial_json_parser
|
import partial_json_parser
|
||||||
from partial_json_parser.core.options import Allow
|
from partial_json_parser.core.options import Allow
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.openai.protocol import Tool, ToolChoice
|
||||||
|
|
||||||
|
|
||||||
def _find_common_prefix(s1: str, s2: str) -> str:
|
def _find_common_prefix(s1: str, s2: str) -> str:
|
||||||
prefix = ""
|
prefix = ""
|
||||||
@@ -37,10 +40,12 @@ def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return (partial_json_parser.loads(input_str, flags), len(input_str))
|
return (partial_json_parser.loads(input_str, flags), len(input_str))
|
||||||
except JSONDecodeError as e:
|
except (JSONDecodeError, IndexError) as e:
|
||||||
if "Extra data" in e.msg:
|
msg = getattr(e, "msg", str(e))
|
||||||
dec = JSONDecoder()
|
if "Extra data" in msg or "pop from empty list" in msg:
|
||||||
return dec.raw_decode(input_str)
|
start = WHITESPACE.match(input_str, 0).end()
|
||||||
|
obj, end = JSONDecoder().raw_decode(input_str, start)
|
||||||
|
return obj, end
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -50,3 +55,89 @@ def _is_complete_json(input_str: str) -> bool:
|
|||||||
return True
|
return True
|
||||||
except JSONDecodeError:
|
except JSONDecodeError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tool_schema_defs(tools: List[Tool]) -> dict:
|
||||||
|
"""
|
||||||
|
Get consolidated $defs from all tools, validating for conflicts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools: List of tools to process
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of consolidated $defs from all tools
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If conflicting $defs are found
|
||||||
|
"""
|
||||||
|
all_defs = {}
|
||||||
|
for tool in tools:
|
||||||
|
if tool.function.parameters is None:
|
||||||
|
continue
|
||||||
|
defs = tool.function.parameters.get("$defs", {})
|
||||||
|
for def_name, def_schema in defs.items():
|
||||||
|
if def_name in all_defs and all_defs[def_name] != def_schema:
|
||||||
|
raise ValueError(
|
||||||
|
f"Tool definition '{def_name}' has "
|
||||||
|
"multiple schemas, which is not "
|
||||||
|
"supported."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
all_defs[def_name] = def_schema
|
||||||
|
return all_defs
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tool_schema(tool: Tool) -> dict:
|
||||||
|
return {
|
||||||
|
"properties": {
|
||||||
|
"name": {"type": "string", "enum": [tool.function.name]},
|
||||||
|
"parameters": (
|
||||||
|
tool.function.parameters
|
||||||
|
if tool.function.parameters
|
||||||
|
else {"type": "object", "properties": {}}
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"required": ["name", "parameters"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_json_schema_constraint(
|
||||||
|
tools: List[Tool], tool_choice: Union[ToolChoice, Literal["required"]]
|
||||||
|
) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Get the JSON schema constraint for the specified tool choice.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_choice: The tool choice specification
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON schema dict, or None if no valid tools found
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(tool_choice, ToolChoice):
|
||||||
|
# For specific function choice, return the user's parameters schema directly
|
||||||
|
fn_name = tool_choice.function.name
|
||||||
|
for tool in tools:
|
||||||
|
if tool.function.name == fn_name:
|
||||||
|
return {
|
||||||
|
"type": "array",
|
||||||
|
"minItems": 1,
|
||||||
|
"maxItems": 1,
|
||||||
|
"items": _get_tool_schema(tool),
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
elif tool_choice == "required":
|
||||||
|
json_schema = {
|
||||||
|
"type": "array",
|
||||||
|
"minItems": 1,
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"anyOf": [_get_tool_schema(tool) for tool in tools],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
json_schema_defs = _get_tool_schema_defs(tools)
|
||||||
|
if json_schema_defs:
|
||||||
|
json_schema["$defs"] = json_schema_defs
|
||||||
|
return json_schema
|
||||||
|
|
||||||
|
return None
|
||||||
|
|||||||
618
test/srt/function_call/test_json_schema_constraint.py
Normal file
618
test/srt/function_call/test_json_schema_constraint.py
Normal file
@@ -0,0 +1,618 @@
|
|||||||
|
"""
|
||||||
|
Tests for JSON schema constraint functionality used by JsonArrayParser
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import jsonschema
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.openai.protocol import (
|
||||||
|
Function,
|
||||||
|
Tool,
|
||||||
|
ToolChoice,
|
||||||
|
ToolChoiceFuncName,
|
||||||
|
)
|
||||||
|
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||||
|
from sglang.srt.function_call.utils import (
|
||||||
|
_get_tool_schema_defs,
|
||||||
|
get_json_schema_constraint,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestJsonSchemaConstraint(unittest.TestCase):
|
||||||
|
"""Test JSON schema constraint generation for tool choices"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up test tools"""
|
||||||
|
self.tools = [
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="get_weather",
|
||||||
|
description="Get weather information",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Location to get weather for",
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"description": "Temperature unit",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="search",
|
||||||
|
description="Search for information",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Search query",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_required_tool_choice_schema(self):
|
||||||
|
"""Test schema generation for tool_choice='required'"""
|
||||||
|
schema = get_json_schema_constraint(self.tools, "required")
|
||||||
|
|
||||||
|
self.assertIsNotNone(schema)
|
||||||
|
jsonschema.Draft202012Validator.check_schema(schema)
|
||||||
|
|
||||||
|
self.assertEqual(schema["type"], "array")
|
||||||
|
self.assertEqual(schema["minItems"], 1)
|
||||||
|
self.assertIn("items", schema)
|
||||||
|
self.assertIn("anyOf", schema["items"])
|
||||||
|
|
||||||
|
# Should have schemas for both tools
|
||||||
|
self.assertEqual(len(schema["items"]["anyOf"]), 2)
|
||||||
|
|
||||||
|
# Check that each tool schema is present
|
||||||
|
tool_names = [
|
||||||
|
item["properties"]["name"]["enum"][0] for item in schema["items"]["anyOf"]
|
||||||
|
]
|
||||||
|
self.assertIn("get_weather", tool_names)
|
||||||
|
self.assertIn("search", tool_names)
|
||||||
|
|
||||||
|
def test_specific_tool_choice_schema(self):
|
||||||
|
"""Test schema generation for specific tool choice"""
|
||||||
|
tool_choice = ToolChoice(
|
||||||
|
type="function", function=ToolChoiceFuncName(name="get_weather")
|
||||||
|
)
|
||||||
|
schema = get_json_schema_constraint(self.tools, tool_choice)
|
||||||
|
|
||||||
|
self.assertIsNotNone(schema)
|
||||||
|
jsonschema.Draft202012Validator.check_schema(schema)
|
||||||
|
|
||||||
|
self.assertEqual(schema["type"], "array")
|
||||||
|
self.assertEqual(schema["minItems"], 1)
|
||||||
|
self.assertEqual(schema["maxItems"], 1)
|
||||||
|
|
||||||
|
# Should only have schema for the specific tool
|
||||||
|
item_schema = schema["items"]
|
||||||
|
self.assertEqual(item_schema["properties"]["name"]["enum"], ["get_weather"])
|
||||||
|
self.assertIn("parameters", item_schema["properties"])
|
||||||
|
|
||||||
|
def test_specific_tool_choice_dict_schema(self):
|
||||||
|
"""Test schema generation for specific tool choice as ToolChoice object"""
|
||||||
|
tool_choice = ToolChoice(
|
||||||
|
type="function", function=ToolChoiceFuncName(name="search")
|
||||||
|
)
|
||||||
|
schema = get_json_schema_constraint(self.tools, tool_choice)
|
||||||
|
|
||||||
|
self.assertIsNotNone(schema)
|
||||||
|
jsonschema.Draft202012Validator.check_schema(schema)
|
||||||
|
|
||||||
|
self.assertEqual(schema["type"], "array")
|
||||||
|
self.assertEqual(schema["minItems"], 1)
|
||||||
|
self.assertEqual(schema["maxItems"], 1)
|
||||||
|
|
||||||
|
# Should only have schema for the specific tool
|
||||||
|
item_schema = schema["items"]
|
||||||
|
self.assertEqual(item_schema["properties"]["name"]["enum"], ["search"])
|
||||||
|
self.assertIn("parameters", item_schema["properties"])
|
||||||
|
|
||||||
|
def test_nonexistent_tool_choice(self):
|
||||||
|
"""Test schema generation for nonexistent tool"""
|
||||||
|
tool_choice = ToolChoice(
|
||||||
|
type="function", function=ToolChoiceFuncName(name="nonexistent")
|
||||||
|
)
|
||||||
|
schema = get_json_schema_constraint(self.tools, tool_choice)
|
||||||
|
|
||||||
|
self.assertIsNone(schema)
|
||||||
|
|
||||||
|
def test_nonexistent_tool_choice_dict(self):
|
||||||
|
"""Test schema generation for nonexistent tool as dict"""
|
||||||
|
tool_choice = {"type": "function", "function": {"name": "nonexistent"}}
|
||||||
|
schema = get_json_schema_constraint(self.tools, tool_choice)
|
||||||
|
|
||||||
|
self.assertIsNone(schema)
|
||||||
|
|
||||||
|
def test_auto_tool_choice_schema(self):
|
||||||
|
"""Test schema generation for tool_choice='auto'"""
|
||||||
|
schema = get_json_schema_constraint(self.tools, "auto")
|
||||||
|
|
||||||
|
self.assertIsNone(schema)
|
||||||
|
|
||||||
|
def test_none_tool_choice_schema(self):
|
||||||
|
"""Test schema generation for tool_choice=None"""
|
||||||
|
schema = get_json_schema_constraint(self.tools, None)
|
||||||
|
|
||||||
|
self.assertIsNone(schema)
|
||||||
|
|
||||||
|
def test_tools_with_defs(self):
|
||||||
|
"""Test schema generation with tools that have $defs"""
|
||||||
|
tools_with_defs = [
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="complex_tool",
|
||||||
|
description="Tool with complex schema",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"data": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"nested": {"$ref": "#/$defs/NestedType"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"$defs": {
|
||||||
|
"NestedType": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"value": {"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
_get_tool_schema_defs(tools_with_defs)
|
||||||
|
except ValueError as e:
|
||||||
|
self.fail(f"Should not raise ValueError, but got: {e}")
|
||||||
|
|
||||||
|
schema = get_json_schema_constraint(tools_with_defs, "required")
|
||||||
|
|
||||||
|
self.assertIsNotNone(schema)
|
||||||
|
jsonschema.Draft202012Validator.check_schema(schema)
|
||||||
|
|
||||||
|
self.assertIn("$defs", schema)
|
||||||
|
self.assertIn("NestedType", schema["$defs"])
|
||||||
|
|
||||||
|
def test_tools_without_parameters(self):
|
||||||
|
"""Test schema generation with tools that have no parameters"""
|
||||||
|
tools_without_params = [
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="simple_tool",
|
||||||
|
description="Tool without parameters",
|
||||||
|
parameters=None,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
schema = get_json_schema_constraint(tools_without_params, "required")
|
||||||
|
|
||||||
|
self.assertIsNotNone(schema)
|
||||||
|
jsonschema.Draft202012Validator.check_schema(schema)
|
||||||
|
|
||||||
|
item_schema = schema["items"]["anyOf"][0]
|
||||||
|
self.assertEqual(
|
||||||
|
item_schema["properties"]["parameters"],
|
||||||
|
{"type": "object", "properties": {}},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_json_schema_vs_ebnf_constraint_generation(self):
|
||||||
|
"""Test direct comparison between JSON schema and EBNF constraint generation"""
|
||||||
|
|
||||||
|
# Test with specific tool choice
|
||||||
|
tool_choice = ToolChoice(
|
||||||
|
type="function", function=ToolChoiceFuncName(name="get_weather")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate JSON schema constraint
|
||||||
|
json_schema = get_json_schema_constraint(self.tools, tool_choice)
|
||||||
|
|
||||||
|
self.assertIsNotNone(json_schema)
|
||||||
|
jsonschema.Draft202012Validator.check_schema(json_schema)
|
||||||
|
|
||||||
|
# Generate EBNF constraint using FunctionCallParser
|
||||||
|
parser = FunctionCallParser(
|
||||||
|
self.tools, "llama3"
|
||||||
|
) # Use a parser that supports EBNF
|
||||||
|
ebnf_constraint = parser.get_ebnf(tool_choice)
|
||||||
|
|
||||||
|
# Verify JSON schema constraint
|
||||||
|
self.assertEqual(json_schema["type"], "array")
|
||||||
|
self.assertEqual(json_schema["minItems"], 1)
|
||||||
|
self.assertEqual(json_schema["maxItems"], 1)
|
||||||
|
|
||||||
|
# Verify EBNF constraint
|
||||||
|
self.assertIsNotNone(ebnf_constraint)
|
||||||
|
self.assertIsInstance(ebnf_constraint, str)
|
||||||
|
self.assertIn("get_weather", ebnf_constraint)
|
||||||
|
|
||||||
|
# Test with required tool choice
|
||||||
|
required_json_schema = get_json_schema_constraint(self.tools, "required")
|
||||||
|
|
||||||
|
self.assertIsNotNone(required_json_schema)
|
||||||
|
jsonschema.Draft202012Validator.check_schema(required_json_schema)
|
||||||
|
|
||||||
|
required_ebnf_constraint = parser.get_ebnf("required")
|
||||||
|
|
||||||
|
# Verify required JSON schema constraint
|
||||||
|
self.assertEqual(required_json_schema["type"], "array")
|
||||||
|
self.assertEqual(required_json_schema["minItems"], 1)
|
||||||
|
self.assertIn("anyOf", required_json_schema["items"])
|
||||||
|
|
||||||
|
# Verify required EBNF constraint
|
||||||
|
self.assertIsNotNone(required_ebnf_constraint)
|
||||||
|
self.assertIsInstance(required_ebnf_constraint, str)
|
||||||
|
|
||||||
|
# Both should contain references to the available tools
|
||||||
|
tool_names = [tool.function.name for tool in self.tools]
|
||||||
|
for tool_name in tool_names:
|
||||||
|
self.assertIn(tool_name, required_ebnf_constraint)
|
||||||
|
|
||||||
|
def test_conflicting_defs_raises_valueerror(self):
|
||||||
|
"""Test that conflicting tool definitions raise ValueError with proper message"""
|
||||||
|
tools_with_conflicting_defs = [
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="tool1",
|
||||||
|
description="Tool 1",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {},
|
||||||
|
"$defs": {
|
||||||
|
"ConflictingType": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"value": {"type": "string"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="tool2",
|
||||||
|
description="Tool 2",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {},
|
||||||
|
"$defs": {
|
||||||
|
"ConflictingType": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"value": {"type": "number"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError) as context:
|
||||||
|
_get_tool_schema_defs(tools_with_conflicting_defs)
|
||||||
|
|
||||||
|
self.assertIn(
|
||||||
|
"Tool definition 'ConflictingType' has multiple schemas",
|
||||||
|
str(context.exception),
|
||||||
|
)
|
||||||
|
self.assertIn("which is not supported", str(context.exception))
|
||||||
|
|
||||||
|
def test_tools_with_empty_defs(self):
|
||||||
|
"""Test tools with empty $defs objects"""
|
||||||
|
tools_with_empty_defs = [
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="empty_defs_tool",
|
||||||
|
description="Tool with empty $defs",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"data": {"type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["data"],
|
||||||
|
"$defs": {},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
_get_tool_schema_defs(tools_with_empty_defs)
|
||||||
|
except ValueError as e:
|
||||||
|
self.fail(f"Should not raise ValueError, but got: {e}")
|
||||||
|
|
||||||
|
schema = get_json_schema_constraint(tools_with_empty_defs, "required")
|
||||||
|
self.assertIsNotNone(schema)
|
||||||
|
jsonschema.Draft202012Validator.check_schema(schema)
|
||||||
|
|
||||||
|
# Should not have $defs section when empty
|
||||||
|
self.assertNotIn("$defs", schema)
|
||||||
|
|
||||||
|
def test_tools_with_identical_defs(self):
|
||||||
|
"""Test different tools with same $defs names but identical schemas (should not raise exception)"""
|
||||||
|
tools_with_identical_defs = [
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="weather_tool",
|
||||||
|
description="Get weather information",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {"$ref": "#/$defs/Location"},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
"$defs": {
|
||||||
|
"Location": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"lat": {"type": "number"},
|
||||||
|
"lon": {"type": "number"},
|
||||||
|
},
|
||||||
|
"required": ["lat", "lon"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="address_tool",
|
||||||
|
description="Get address information",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"address": {"$ref": "#/$defs/Location"},
|
||||||
|
},
|
||||||
|
"required": ["address"],
|
||||||
|
"$defs": {
|
||||||
|
"Location": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"lat": {"type": "number"},
|
||||||
|
"lon": {"type": "number"},
|
||||||
|
},
|
||||||
|
"required": ["lat", "lon"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
_get_tool_schema_defs(tools_with_identical_defs)
|
||||||
|
except ValueError as e:
|
||||||
|
self.fail(
|
||||||
|
f"Should not raise ValueError for identical schemas, but got: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Also test that schema generation works
|
||||||
|
schema = get_json_schema_constraint(tools_with_identical_defs, "required")
|
||||||
|
self.assertIsNotNone(schema)
|
||||||
|
jsonschema.Draft202012Validator.check_schema(schema)
|
||||||
|
|
||||||
|
# Verify both tools are present
|
||||||
|
tool_names = [
|
||||||
|
item["properties"]["name"]["enum"][0] for item in schema["items"]["anyOf"]
|
||||||
|
]
|
||||||
|
self.assertIn("weather_tool", tool_names)
|
||||||
|
self.assertIn("address_tool", tool_names)
|
||||||
|
|
||||||
|
# Should have $defs with Location
|
||||||
|
self.assertIn("$defs", schema)
|
||||||
|
self.assertIn("Location", schema["$defs"])
|
||||||
|
|
||||||
|
def test_tools_with_nested_defs(self):
|
||||||
|
"""Test tools with nested $defs"""
|
||||||
|
tools_with_nested_defs = [
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="complex_tool",
|
||||||
|
description="Tool with nested $defs",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"user": {"$ref": "#/$defs/User"},
|
||||||
|
"settings": {"$ref": "#/$defs/Settings"},
|
||||||
|
},
|
||||||
|
"required": ["user"],
|
||||||
|
"$defs": {
|
||||||
|
"User": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"id": {"type": "string"},
|
||||||
|
"profile": {"$ref": "#/$defs/Profile"},
|
||||||
|
},
|
||||||
|
"required": ["id"],
|
||||||
|
},
|
||||||
|
"Profile": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {"type": "string"},
|
||||||
|
"email": {"type": "string", "format": "email"},
|
||||||
|
},
|
||||||
|
"required": ["name"],
|
||||||
|
},
|
||||||
|
"Settings": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"theme": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["light", "dark"],
|
||||||
|
},
|
||||||
|
"notifications": {"type": "boolean"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
_get_tool_schema_defs(tools_with_nested_defs)
|
||||||
|
except ValueError as e:
|
||||||
|
self.fail(f"Should not raise ValueError, but got: {e}")
|
||||||
|
|
||||||
|
schema = get_json_schema_constraint(tools_with_nested_defs, "required")
|
||||||
|
self.assertIsNotNone(schema)
|
||||||
|
jsonschema.Draft202012Validator.check_schema(schema)
|
||||||
|
|
||||||
|
# Verify all $defs are properly included
|
||||||
|
self.assertIn("$defs", schema)
|
||||||
|
self.assertIn("User", schema["$defs"])
|
||||||
|
self.assertIn("Profile", schema["$defs"])
|
||||||
|
self.assertIn("Settings", schema["$defs"])
|
||||||
|
|
||||||
|
def test_mixed_tools_with_and_without_defs(self):
|
||||||
|
"""Test mixed tools with and without $defs"""
|
||||||
|
mixed_tools = [
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="simple_tool",
|
||||||
|
description="Simple tool without $defs",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="complex_tool",
|
||||||
|
description="Complex tool with $defs",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"data": {"$ref": "#/$defs/DataType"},
|
||||||
|
},
|
||||||
|
"required": ["data"],
|
||||||
|
"$defs": {
|
||||||
|
"DataType": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"value": {"type": "string"},
|
||||||
|
"metadata": {"type": "object"},
|
||||||
|
},
|
||||||
|
"required": ["value"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="another_simple_tool",
|
||||||
|
description="Another simple tool",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"id": {"type": "integer"},
|
||||||
|
},
|
||||||
|
"required": ["id"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
_get_tool_schema_defs(mixed_tools)
|
||||||
|
except ValueError as e:
|
||||||
|
self.fail(f"Should not raise ValueError, but got: {e}")
|
||||||
|
|
||||||
|
schema = get_json_schema_constraint(mixed_tools, "required")
|
||||||
|
self.assertIsNotNone(schema)
|
||||||
|
jsonschema.Draft202012Validator.check_schema(schema)
|
||||||
|
|
||||||
|
# Should have $defs from the complex tool
|
||||||
|
self.assertIn("$defs", schema)
|
||||||
|
self.assertIn("DataType", schema["$defs"])
|
||||||
|
|
||||||
|
# Should have all three tools
|
||||||
|
tool_names = [
|
||||||
|
item["properties"]["name"]["enum"][0] for item in schema["items"]["anyOf"]
|
||||||
|
]
|
||||||
|
self.assertEqual(len(tool_names), 3)
|
||||||
|
self.assertIn("simple_tool", tool_names)
|
||||||
|
self.assertIn("complex_tool", tool_names)
|
||||||
|
self.assertIn("another_simple_tool", tool_names)
|
||||||
|
|
||||||
|
def test_tools_with_defs_but_no_refs(self):
|
||||||
|
"""Test tools with $defs but no $ref usage"""
|
||||||
|
tools_with_unused_defs = [
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="unused_defs_tool",
|
||||||
|
description="Tool with $defs but no $ref usage",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"data": {"type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["data"],
|
||||||
|
"$defs": {
|
||||||
|
"UnusedType": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"value": {"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
_get_tool_schema_defs(tools_with_unused_defs)
|
||||||
|
except ValueError as e:
|
||||||
|
self.fail(f"Should not raise ValueError, but got: {e}")
|
||||||
|
|
||||||
|
schema = get_json_schema_constraint(tools_with_unused_defs, "required")
|
||||||
|
self.assertIsNotNone(schema)
|
||||||
|
jsonschema.Draft202012Validator.check_schema(schema)
|
||||||
|
|
||||||
|
# Should still include $defs even if not referenced
|
||||||
|
self.assertIn("$defs", schema)
|
||||||
|
self.assertIn("UnusedType", schema["$defs"])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -354,7 +354,7 @@ class ServingChatTestCase(unittest.TestCase):
|
|||||||
{"type": "function", "function": {"name": "get_weather"}},
|
{"type": "function", "function": {"name": "get_weather"}},
|
||||||
]
|
]
|
||||||
|
|
||||||
tool_calls, remaining_text, _ = self.chat._process_tool_calls(
|
tool_calls, remaining_text, finish_reason = self.chat._process_tool_calls(
|
||||||
text="<|tool_calls_section_begin|>...",
|
text="<|tool_calls_section_begin|>...",
|
||||||
tools=tools,
|
tools=tools,
|
||||||
finish_reason=finish_reason,
|
finish_reason=finish_reason,
|
||||||
|
|||||||
@@ -73,11 +73,11 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
|
|||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"a": {
|
"a": {
|
||||||
"type": "int",
|
"type": "integer",
|
||||||
"description": "A number",
|
"description": "A number",
|
||||||
},
|
},
|
||||||
"b": {
|
"b": {
|
||||||
"type": "int",
|
"type": "integer",
|
||||||
"description": "A number",
|
"description": "A number",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -128,11 +128,11 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
|
|||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"a": {
|
"a": {
|
||||||
"type": "int",
|
"type": "integer",
|
||||||
"description": "A number",
|
"description": "A number",
|
||||||
},
|
},
|
||||||
"b": {
|
"b": {
|
||||||
"type": "int",
|
"type": "integer",
|
||||||
"description": "A number",
|
"description": "A number",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -343,6 +343,142 @@ class TestToolChoiceLlama32(CustomTestCase):
|
|||||||
|
|
||||||
self.assertEqual(found_name, "get_weather")
|
self.assertEqual(found_name, "get_weather")
|
||||||
|
|
||||||
|
def test_required_streaming_arguments_chunks_json(self):
|
||||||
|
"""In streaming required mode, complete tool call arguments should be valid JSON when all chunks are combined"""
|
||||||
|
tools = self.get_test_tools()
|
||||||
|
messages = self.get_test_messages()
|
||||||
|
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model=self.model_name,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=1024,
|
||||||
|
temperature=0.1,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice="required",
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Collect all tool call chunks and reconstruct complete tool calls
|
||||||
|
tool_calls_by_index = {}
|
||||||
|
for chunk in response:
|
||||||
|
if chunk.choices[0].delta.tool_calls:
|
||||||
|
for tool_call_delta in chunk.choices[0].delta.tool_calls:
|
||||||
|
tool_index = tool_call_delta.index
|
||||||
|
|
||||||
|
# Initialize tool call if not seen before
|
||||||
|
if tool_index not in tool_calls_by_index:
|
||||||
|
tool_calls_by_index[tool_index] = {
|
||||||
|
"id": tool_call_delta.id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "", "arguments": ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Update function name if present (first chunk)
|
||||||
|
if tool_call_delta.function and tool_call_delta.function.name:
|
||||||
|
tool_calls_by_index[tool_index]["function"][
|
||||||
|
"name"
|
||||||
|
] = tool_call_delta.function.name
|
||||||
|
|
||||||
|
# Accumulate arguments (all chunks)
|
||||||
|
if tool_call_delta.function and tool_call_delta.function.arguments:
|
||||||
|
tool_calls_by_index[tool_index]["function"][
|
||||||
|
"arguments"
|
||||||
|
] += tool_call_delta.function.arguments
|
||||||
|
|
||||||
|
self.assertGreater(len(tool_calls_by_index), 0)
|
||||||
|
|
||||||
|
# Validate that complete tool calls have valid JSON arguments
|
||||||
|
for tool_call in tool_calls_by_index.values():
|
||||||
|
self.assertIsNotNone(tool_call["function"]["name"])
|
||||||
|
self.assertIsNotNone(tool_call["function"]["arguments"])
|
||||||
|
|
||||||
|
# The complete arguments should be valid JSON
|
||||||
|
try:
|
||||||
|
args = json.loads(tool_call["function"]["arguments"])
|
||||||
|
self.assertIsInstance(args, dict)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
self.fail(
|
||||||
|
f"Invalid JSON in complete tool call arguments: {tool_call['function']['arguments']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_complex_parameters_required_non_streaming(self):
|
||||||
|
"""Validate complex nested parameter schemas in non-streaming required mode"""
|
||||||
|
complex_tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "analyze_data",
|
||||||
|
"description": "Analyze complex data structures",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"data": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"metrics": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
},
|
||||||
|
"config": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"threshold": {"type": "number"},
|
||||||
|
"enabled": {"type": "boolean"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["metrics"],
|
||||||
|
},
|
||||||
|
"options": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {"type": "string"},
|
||||||
|
"value": {"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["data"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Analyze some data with metrics and configuration",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model=self.model_name,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=1024,
|
||||||
|
temperature=0.1,
|
||||||
|
tools=complex_tools,
|
||||||
|
tool_choice="required",
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_calls = response.choices[0].message.tool_calls
|
||||||
|
self.assertIsNotNone(tool_calls)
|
||||||
|
self.assertGreater(len(tool_calls), 0)
|
||||||
|
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
self.assertEqual(tool_call.function.name, "analyze_data")
|
||||||
|
try:
|
||||||
|
args = json.loads(tool_call.function.arguments)
|
||||||
|
self.assertIsInstance(args, dict)
|
||||||
|
self.assertIn("data", args)
|
||||||
|
self.assertIsInstance(args["data"], dict)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
self.fail(
|
||||||
|
f"Invalid JSON in complex tool call arguments: {tool_call.function.arguments}"
|
||||||
|
)
|
||||||
|
|
||||||
def test_multi_tool_scenario_auto(self):
|
def test_multi_tool_scenario_auto(self):
|
||||||
"""Test multi-tool scenario with tool_choice='auto'"""
|
"""Test multi-tool scenario with tool_choice='auto'"""
|
||||||
tools = self.get_travel_tools()
|
tools = self.get_travel_tools()
|
||||||
@@ -408,6 +544,10 @@ class TestToolChoiceLlama32(CustomTestCase):
|
|||||||
available_names = [tool["function"]["name"] for tool in tools]
|
available_names = [tool["function"]["name"] for tool in tools]
|
||||||
expected_functions = {"get_weather", "get_tourist_attractions"}
|
expected_functions = {"get_weather", "get_tourist_attractions"}
|
||||||
|
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
self.assertIsNotNone(tool_call.function.name)
|
||||||
|
self.assertIsNotNone(tool_call.function.arguments)
|
||||||
|
|
||||||
if self._is_flaky_test():
|
if self._is_flaky_test():
|
||||||
# For flaky tests, just ensure basic functionality works
|
# For flaky tests, just ensure basic functionality works
|
||||||
self.assertGreater(
|
self.assertGreater(
|
||||||
@@ -432,22 +572,15 @@ class TestToolChoiceLlama32(CustomTestCase):
|
|||||||
|
|
||||||
def test_error_handling_invalid_tool_choice(self):
|
def test_error_handling_invalid_tool_choice(self):
|
||||||
"""Test error handling for invalid tool_choice"""
|
"""Test error handling for invalid tool_choice"""
|
||||||
import logging
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
tools = self.get_test_tools()
|
tools = self.get_test_tools()
|
||||||
messages = self.get_test_messages()
|
messages = self.get_test_messages()
|
||||||
|
|
||||||
# Test with invalid function name
|
# Test with invalid function name
|
||||||
tool_choice = {"type": "function", "function": {"name": "nonexistent_function"}}
|
tool_choice = {"type": "function", "function": {"name": "nonexistent_function"}}
|
||||||
|
|
||||||
# The behavior could be either:
|
# Expect a 400 BadRequestError to be raised for invalid tool_choice
|
||||||
# 1. Log a warning and continue (if fallback is implemented)
|
with self.assertRaises(openai.BadRequestError) as context:
|
||||||
# 2. Raise an exception (if strict validation is implemented)
|
self.client.chat.completions.create(
|
||||||
|
|
||||||
# First try to capture any logging that might happen
|
|
||||||
with patch("logging.warning") as mock_warning:
|
|
||||||
response = self.client.chat.completions.create(
|
|
||||||
model=self.model_name,
|
model=self.model_name,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_tokens=2048,
|
max_tokens=2048,
|
||||||
@@ -456,11 +589,173 @@ class TestToolChoiceLlama32(CustomTestCase):
|
|||||||
stream=False,
|
stream=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNotNone(response.choices[0].message)
|
# Verify the error message contains the expected text
|
||||||
|
self.assertIn(
|
||||||
|
"Tool 'nonexistent_function' not found in tools list",
|
||||||
|
str(context.exception),
|
||||||
|
)
|
||||||
|
|
||||||
if mock_warning.called:
|
def test_invalid_tool_missing_name(self):
|
||||||
warning_message = mock_warning.call_args[0][0]
|
"""Test what happens when user doesn't provide a tool name in request"""
|
||||||
self.assertIn("nonexistent_function", warning_message)
|
# Test with malformed JSON in tool parameters - missing required "name" field
|
||||||
|
invalid_tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
# Missing required "name" field
|
||||||
|
"description": "Test function with invalid schema",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"test_field": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Test field",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["test_field"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Test the function",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Should raise BadRequestError due to missing required 'name' field
|
||||||
|
with self.assertRaises(openai.BadRequestError) as context:
|
||||||
|
self.client.chat.completions.create(
|
||||||
|
model=self.model_name,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=100,
|
||||||
|
temperature=0.1,
|
||||||
|
tools=invalid_tools,
|
||||||
|
tool_choice="required",
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the error message indicates missing name field
|
||||||
|
error_msg = str(context.exception).lower()
|
||||||
|
self.assertIn("name", error_msg)
|
||||||
|
|
||||||
|
def test_invalid_json_schema_in_tool(self):
|
||||||
|
"""Test what happens when tool function has invalid JSON schema"""
|
||||||
|
invalid_tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "test_function",
|
||||||
|
"description": "Test function with invalid JSON schema",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"invalid_field": {
|
||||||
|
"type": "unknown_type", # Invalid type
|
||||||
|
"description": "This field has an invalid type",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["invalid_field"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Test the function",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Should raise BadRequestError due to invalid JSON schema in tool parameters
|
||||||
|
with self.assertRaises(openai.BadRequestError) as context:
|
||||||
|
self.client.chat.completions.create(
|
||||||
|
model=self.model_name,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=100,
|
||||||
|
temperature=0.1,
|
||||||
|
tools=invalid_tools,
|
||||||
|
tool_choice="required",
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the error message indicates invalid JSON schema for parameters field
|
||||||
|
error_msg = str(context.exception).lower()
|
||||||
|
self.assertIn("invalid 'parameters' schema", error_msg)
|
||||||
|
|
||||||
|
def test_conflicting_defs_required_tool_choice(self):
|
||||||
|
"""Test that conflicting $defs with required tool_choice returns 400 error"""
|
||||||
|
conflicting_tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "tool1",
|
||||||
|
"description": "Tool 1 with conflicting $defs",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"data": {"$ref": "#/$defs/DataType"},
|
||||||
|
},
|
||||||
|
"required": ["data"],
|
||||||
|
"$defs": {
|
||||||
|
"DataType": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"value": {"type": "string"}},
|
||||||
|
"required": ["value"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "tool2",
|
||||||
|
"description": "Tool 2 with conflicting $defs",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"data": {"$ref": "#/$defs/DataType"},
|
||||||
|
},
|
||||||
|
"required": ["data"],
|
||||||
|
"$defs": {
|
||||||
|
"DataType": { # Different definition for DataType
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"value": {"type": "number"}},
|
||||||
|
"required": ["value"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Test the conflicting tools",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Should raise BadRequestError due to conflicting $defs
|
||||||
|
with self.assertRaises(openai.BadRequestError) as context:
|
||||||
|
self.client.chat.completions.create(
|
||||||
|
model=self.model_name,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=100,
|
||||||
|
temperature=0.1,
|
||||||
|
tools=conflicting_tools,
|
||||||
|
tool_choice="required",
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the error message indicates conflicting tool definitions
|
||||||
|
error_msg = str(context.exception).lower()
|
||||||
|
self.assertIn("multiple schemas", error_msg)
|
||||||
|
self.assertIn("not supported", error_msg)
|
||||||
|
|
||||||
|
|
||||||
class TestToolChoiceQwen25(TestToolChoiceLlama32):
|
class TestToolChoiceQwen25(TestToolChoiceLlama32):
|
||||||
@@ -516,6 +811,16 @@ class TestToolChoiceMistral(TestToolChoiceLlama32):
|
|||||||
cls.base_url += "/v1"
|
cls.base_url += "/v1"
|
||||||
cls.tokenizer = get_tokenizer(cls.model)
|
cls.tokenizer = get_tokenizer(cls.model)
|
||||||
|
|
||||||
|
@unittest.skip("Fails due to whitespace issue with Mistral - skipping")
|
||||||
|
def test_multi_tool_scenario_required(self):
|
||||||
|
"""Test multi-tool scenario with tool_choice='required'"""
|
||||||
|
super().test_multi_tool_scenario_required()
|
||||||
|
|
||||||
|
@unittest.skip("Fails due to whitespace issue with Mistral - skipping")
|
||||||
|
def test_complex_parameters_required_non_streaming(self):
|
||||||
|
"""Validate complex nested parameter schemas in non-streaming required mode"""
|
||||||
|
super().test_complex_parameters_required_non_streaming()
|
||||||
|
|
||||||
|
|
||||||
# Skip for ci test
|
# Skip for ci test
|
||||||
# class TestToolChoiceGLM45(TestToolChoiceLlama32):
|
# class TestToolChoiceGLM45(TestToolChoiceLlama32):
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ suites = {
|
|||||||
TestFile("openai_server/features/test_reasoning_content.py", 89),
|
TestFile("openai_server/features/test_reasoning_content.py", 89),
|
||||||
TestFile("openai_server/function_call/test_openai_function_calling.py", 60),
|
TestFile("openai_server/function_call/test_openai_function_calling.py", 60),
|
||||||
TestFile("openai_server/function_call/test_tool_choice.py", 226),
|
TestFile("openai_server/function_call/test_tool_choice.py", 226),
|
||||||
|
TestFile("function_call/test_json_schema_constraint.py", 30),
|
||||||
TestFile("openai_server/validation/test_large_max_new_tokens.py", 41),
|
TestFile("openai_server/validation/test_large_max_new_tokens.py", 41),
|
||||||
TestFile("openai_server/validation/test_matched_stop.py", 60),
|
TestFile("openai_server/validation/test_matched_stop.py", 60),
|
||||||
TestFile("openai_server/validation/test_openai_server_ignore_eos.py", 85),
|
TestFile("openai_server/validation/test_openai_server_ignore_eos.py", 85),
|
||||||
@@ -205,6 +206,7 @@ suite_amd = {
|
|||||||
TestFile("openai_server/features/test_reasoning_content.py", 89),
|
TestFile("openai_server/features/test_reasoning_content.py", 89),
|
||||||
TestFile("openai_server/function_call/test_openai_function_calling.py", 60),
|
TestFile("openai_server/function_call/test_openai_function_calling.py", 60),
|
||||||
TestFile("openai_server/function_call/test_tool_choice.py", 226),
|
TestFile("openai_server/function_call/test_tool_choice.py", 226),
|
||||||
|
TestFile("function_call/test_json_schema_constraint.py", 30),
|
||||||
TestFile("openai_server/validation/test_large_max_new_tokens.py", 41),
|
TestFile("openai_server/validation/test_large_max_new_tokens.py", 41),
|
||||||
TestFile("openai_server/validation/test_matched_stop.py", 60),
|
TestFile("openai_server/validation/test_matched_stop.py", 60),
|
||||||
TestFile("openai_server/validation/test_openai_server_ignore_eos.py", 85),
|
TestFile("openai_server/validation/test_openai_server_ignore_eos.py", 85),
|
||||||
|
|||||||
@@ -5,8 +5,10 @@ from xgrammar import GrammarCompiler, TokenizerInfo
|
|||||||
|
|
||||||
from sglang.srt.entrypoints.openai.protocol import Function, Tool
|
from sglang.srt.entrypoints.openai.protocol import Function, Tool
|
||||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||||
|
from sglang.srt.function_call.core_types import StreamingParseResult
|
||||||
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
|
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
|
||||||
from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector
|
from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector
|
||||||
|
from sglang.srt.function_call.json_array_parser import JsonArrayParser
|
||||||
from sglang.srt.function_call.kimik2_detector import KimiK2Detector
|
from sglang.srt.function_call.kimik2_detector import KimiK2Detector
|
||||||
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
from sglang.srt.function_call.llama32_detector import Llama32Detector
|
||||||
from sglang.srt.function_call.mistral_detector import MistralDetector
|
from sglang.srt.function_call.mistral_detector import MistralDetector
|
||||||
@@ -2190,5 +2192,322 @@ class TestGlm4MoeDetector(unittest.TestCase):
|
|||||||
self.assertEqual(self.detector._buffer, "")
|
self.assertEqual(self.detector._buffer, "")
|
||||||
|
|
||||||
|
|
||||||
|
class TestJsonArrayParser(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
# Create sample tools for testing
|
||||||
|
self.tools = [
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="get_weather",
|
||||||
|
description="Get weather information",
|
||||||
|
parameters={
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Location to get weather for",
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Temperature unit",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="search",
|
||||||
|
description="Search for information",
|
||||||
|
parameters={
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Search query",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
self.detector = JsonArrayParser()
|
||||||
|
|
||||||
|
def test_json_detector_ebnf(self):
|
||||||
|
"""Test that the JsonArrayParser returns NotImplementedError for EBNF."""
|
||||||
|
with self.assertRaises(NotImplementedError) as context:
|
||||||
|
self.detector.build_ebnf(self.tools)
|
||||||
|
self.assertIn(
|
||||||
|
"EBNF generation is not supported for JSON schema constraints",
|
||||||
|
str(context.exception),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_parse_streaming_increment_malformed_json(self):
|
||||||
|
"""Test parsing with malformed JSON"""
|
||||||
|
# Test with malformed JSON
|
||||||
|
text = '[{"name": "get_weather", "parameters": {"location": "Tokyo"'
|
||||||
|
result = self.detector.parse_streaming_increment(text, self.tools)
|
||||||
|
|
||||||
|
# Should not crash and return a valid result
|
||||||
|
self.assertIsInstance(result, StreamingParseResult)
|
||||||
|
|
||||||
|
text = "[{}}}]"
|
||||||
|
result = self.detector.parse_streaming_increment(text, self.tools)
|
||||||
|
|
||||||
|
self.assertIsInstance(result, StreamingParseResult)
|
||||||
|
|
||||||
|
def test_parse_streaming_increment_empty_input(self):
|
||||||
|
"""Test parsing with empty input"""
|
||||||
|
result = self.detector.parse_streaming_increment("", self.tools)
|
||||||
|
self.assertEqual(len(result.calls), 0)
|
||||||
|
self.assertEqual(result.normal_text, "")
|
||||||
|
|
||||||
|
def test_parse_streaming_increment_whitespace_handling(self):
|
||||||
|
"""Test parsing with various whitespace scenarios"""
|
||||||
|
# Test with leading/trailing whitespace split across chunks
|
||||||
|
chunk1 = ' [{"name": "get_weather", "parameters": '
|
||||||
|
result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
|
||||||
|
self.assertIsInstance(result1, StreamingParseResult)
|
||||||
|
chunk2 = '{"location": "Tokyo"}}] '
|
||||||
|
result2 = self.detector.parse_streaming_increment(chunk2, self.tools)
|
||||||
|
|
||||||
|
# The base class should handle this
|
||||||
|
self.assertIsInstance(result2, StreamingParseResult)
|
||||||
|
|
||||||
|
def test_parse_streaming_increment_nested_objects(self):
|
||||||
|
"""Test parsing with nested JSON objects"""
|
||||||
|
chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tokyo", '
|
||||||
|
result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
|
||||||
|
self.assertIsInstance(result1, StreamingParseResult)
|
||||||
|
chunk2 = '"nested": {"key": "value"}}}]'
|
||||||
|
result2 = self.detector.parse_streaming_increment(chunk2, self.tools)
|
||||||
|
|
||||||
|
# The base class should handle this
|
||||||
|
self.assertIsInstance(result2, StreamingParseResult)
|
||||||
|
|
||||||
|
def test_json_parsing_with_commas(self):
|
||||||
|
"""Test that JSON parsing works correctly with comma separators"""
|
||||||
|
# Stream two complete objects, at least 2 chunks per tool call
|
||||||
|
chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tok'
|
||||||
|
result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
|
||||||
|
self.assertIsInstance(result1, StreamingParseResult)
|
||||||
|
chunk2 = 'yo"}},'
|
||||||
|
result2 = self.detector.parse_streaming_increment(chunk2, self.tools)
|
||||||
|
self.assertIsInstance(result2, StreamingParseResult)
|
||||||
|
|
||||||
|
chunk3 = '{"name": "get_weather", "parameters": {"location": "Par'
|
||||||
|
result3 = self.detector.parse_streaming_increment(chunk3, self.tools)
|
||||||
|
self.assertIsInstance(result3, StreamingParseResult)
|
||||||
|
chunk4 = 'is"}}]'
|
||||||
|
result4 = self.detector.parse_streaming_increment(chunk4, self.tools)
|
||||||
|
self.assertIsInstance(result4, StreamingParseResult)
|
||||||
|
self.assertGreater(
|
||||||
|
len(result4.calls), 0, "Should parse tool calls from text with separators"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_braces_in_strings(self):
|
||||||
|
"""Test that JSON with } characters inside strings works correctly"""
|
||||||
|
# Test case: JSON array with } inside string values - streamed across chunks
|
||||||
|
chunk1 = '[{"name": "get_weather", "parameters": {"location": "has } inside"'
|
||||||
|
result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
|
||||||
|
self.assertIsInstance(result1, StreamingParseResult)
|
||||||
|
chunk2 = "}}"
|
||||||
|
result2 = self.detector.parse_streaming_increment(chunk2, self.tools)
|
||||||
|
self.assertIsInstance(result2, StreamingParseResult)
|
||||||
|
self.assertGreater(
|
||||||
|
len(result2.calls), 0, "Should parse tool call with } in string"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with separator (streaming in progress)
|
||||||
|
chunk3 = '[{"name": "get_weather", "parameters": {"location": "has } inside"}'
|
||||||
|
result3 = self.detector.parse_streaming_increment(chunk3, self.tools)
|
||||||
|
self.assertIsInstance(result3, StreamingParseResult)
|
||||||
|
chunk4 = "},"
|
||||||
|
result4 = self.detector.parse_streaming_increment(chunk4, self.tools)
|
||||||
|
self.assertIsInstance(result4, StreamingParseResult)
|
||||||
|
chunk5 = '{"name": "get_weather"'
|
||||||
|
result5 = self.detector.parse_streaming_increment(chunk5, self.tools)
|
||||||
|
self.assertIsInstance(result5, StreamingParseResult)
|
||||||
|
self.assertGreater(
|
||||||
|
len(result5.calls),
|
||||||
|
0,
|
||||||
|
"Should parse tool calls with separator and } in string",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_separator_in_same_chunk(self):
|
||||||
|
"""Test that separator already present in chunk works correctly"""
|
||||||
|
# Test case: separator already in the chunk (streaming in progress) with 2+ chunks per tool call
|
||||||
|
chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tokyo"'
|
||||||
|
result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
|
||||||
|
self.assertIsInstance(result1, StreamingParseResult)
|
||||||
|
chunk2 = '}},{"name": "get_weather"'
|
||||||
|
result2 = self.detector.parse_streaming_increment(chunk2, self.tools)
|
||||||
|
self.assertIsInstance(result2, StreamingParseResult)
|
||||||
|
self.assertGreater(
|
||||||
|
len(result2.calls),
|
||||||
|
0,
|
||||||
|
"Should parse tool calls with separator in same chunk",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_separator_in_separate_chunk(self):
|
||||||
|
"""Test that separator in separate chunk works correctly"""
|
||||||
|
# Test case: separator in separate chunk - this tests streaming behavior
|
||||||
|
chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tokyo"}}'
|
||||||
|
chunk2 = ","
|
||||||
|
chunk3 = '{"name": "get_weather", "parameters": {"location": "Paris"}}'
|
||||||
|
|
||||||
|
# Process first chunk
|
||||||
|
result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
|
||||||
|
self.assertIsInstance(result1, StreamingParseResult)
|
||||||
|
|
||||||
|
# Process separator chunk
|
||||||
|
result2 = self.detector.parse_streaming_increment(chunk2, self.tools)
|
||||||
|
self.assertIsInstance(result2, StreamingParseResult)
|
||||||
|
|
||||||
|
# Process second chunk (streaming in progress)
|
||||||
|
result3 = self.detector.parse_streaming_increment(chunk3, self.tools)
|
||||||
|
self.assertIsInstance(result3, StreamingParseResult)
|
||||||
|
|
||||||
|
def test_incomplete_json_across_chunks(self):
|
||||||
|
"""Test that incomplete JSON across chunks works correctly"""
|
||||||
|
# Test case: incomplete JSON across chunks - this tests streaming behavior
|
||||||
|
chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tokyo"'
|
||||||
|
chunk2 = '}},{"name": "get_weather"'
|
||||||
|
|
||||||
|
# Process first chunk (incomplete)
|
||||||
|
result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
|
||||||
|
self.assertIsInstance(result1, StreamingParseResult)
|
||||||
|
|
||||||
|
# Process second chunk (completes first object and starts second, streaming in progress)
|
||||||
|
result2 = self.detector.parse_streaming_increment(chunk2, self.tools)
|
||||||
|
self.assertIsInstance(result2, StreamingParseResult)
|
||||||
|
|
||||||
|
def test_malformed_json_recovery(self):
|
||||||
|
"""Test that malformed JSON recovers gracefully"""
|
||||||
|
# Test with malformed JSON - should handle gracefully
|
||||||
|
malformed_text = (
|
||||||
|
'[{"name": "get_weather", "parameters": {"location": "unclosed string'
|
||||||
|
)
|
||||||
|
|
||||||
|
result1 = self.detector.parse_streaming_increment(malformed_text, self.tools)
|
||||||
|
self.assertIsInstance(result1, StreamingParseResult)
|
||||||
|
|
||||||
|
# Test valid JSON after malformed - streamed across 2 chunks (streaming in progress)
|
||||||
|
valid_chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tok'
|
||||||
|
result2 = self.detector.parse_streaming_increment(valid_chunk1, self.tools)
|
||||||
|
self.assertIsInstance(result2, StreamingParseResult)
|
||||||
|
valid_chunk2 = 'yo"}}'
|
||||||
|
result3 = self.detector.parse_streaming_increment(valid_chunk2, self.tools)
|
||||||
|
self.assertIsInstance(result3, StreamingParseResult)
|
||||||
|
|
||||||
|
def test_nested_objects_with_commas(self):
|
||||||
|
"""Test that nested objects with commas inside work correctly"""
|
||||||
|
# Test with nested objects that have commas - should work with json.loads()
|
||||||
|
chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tok'
|
||||||
|
result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
|
||||||
|
self.assertIsInstance(result1, StreamingParseResult)
|
||||||
|
chunk2 = 'yo", "unit": "celsius"}}'
|
||||||
|
result2 = self.detector.parse_streaming_increment(chunk2, self.tools)
|
||||||
|
self.assertIsInstance(result2, StreamingParseResult)
|
||||||
|
self.assertGreater(
|
||||||
|
len(result2.calls), 0, "Should parse tool call with nested objects"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_empty_objects(self):
|
||||||
|
"""Test that empty objects work correctly"""
|
||||||
|
# Test with empty objects - should work with json.loads()
|
||||||
|
chunk1 = '[{"name": "get_weather", "parameters": '
|
||||||
|
result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
|
||||||
|
self.assertIsInstance(result1, StreamingParseResult)
|
||||||
|
chunk2 = "{}}"
|
||||||
|
result2 = self.detector.parse_streaming_increment(chunk2, self.tools)
|
||||||
|
self.assertIsInstance(result2, StreamingParseResult)
|
||||||
|
|
||||||
|
def test_whitespace_handling(self):
|
||||||
|
"""Test that various whitespace scenarios work correctly"""
|
||||||
|
# Test with various whitespace patterns - should work with json.loads()
|
||||||
|
chunk1 = ' \n\n [{"name": "get_weather", "parameters": '
|
||||||
|
result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
|
||||||
|
self.assertIsInstance(result1, StreamingParseResult)
|
||||||
|
chunk2 = '{"location": "Tokyo"}}'
|
||||||
|
result2 = self.detector.parse_streaming_increment(chunk2, self.tools)
|
||||||
|
self.assertIsInstance(result2, StreamingParseResult)
|
||||||
|
|
||||||
|
def test_multiple_commas_in_chunk(self):
|
||||||
|
"""Test that multiple commas in a single chunk work correctly"""
|
||||||
|
# Stream multiple tool calls ensuring at least 2 chunks per complete tool call
|
||||||
|
chunk1 = '[{"name": "get_weather", "parameters": {"location": "To'
|
||||||
|
result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
|
||||||
|
self.assertIsInstance(result1, StreamingParseResult)
|
||||||
|
chunk2 = 'kyo"}},'
|
||||||
|
result2 = self.detector.parse_streaming_increment(chunk2, self.tools)
|
||||||
|
self.assertIsInstance(result2, StreamingParseResult)
|
||||||
|
|
||||||
|
chunk3 = '{"name": "get_weather", "parameters": {"location": "Pa'
|
||||||
|
result3 = self.detector.parse_streaming_increment(chunk3, self.tools)
|
||||||
|
self.assertIsInstance(result3, StreamingParseResult)
|
||||||
|
chunk4 = 'ris"}},'
|
||||||
|
result4 = self.detector.parse_streaming_increment(chunk4, self.tools)
|
||||||
|
self.assertIsInstance(result4, StreamingParseResult)
|
||||||
|
|
||||||
|
chunk5 = '{"name": "get_weather"'
|
||||||
|
result5 = self.detector.parse_streaming_increment(chunk5, self.tools)
|
||||||
|
self.assertIsInstance(result5, StreamingParseResult)
|
||||||
|
self.assertGreater(
|
||||||
|
len(result5.calls), 0, "Should parse tool calls with multiple commas"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_complete_tool_call_with_trailing_comma(self):
|
||||||
|
"""Test that complete tool call with trailing comma parses correctly"""
|
||||||
|
# Test case: complete tool call followed by comma at end of chunk (split across 2 chunks)
|
||||||
|
chunk1 = '[{"name": "get_weather", "parameters": {"location": "Tokyo"}'
|
||||||
|
result1 = self.detector.parse_streaming_increment(chunk1, self.tools)
|
||||||
|
self.assertIsInstance(result1, StreamingParseResult)
|
||||||
|
chunk2 = "}, "
|
||||||
|
result2 = self.detector.parse_streaming_increment(chunk2, self.tools)
|
||||||
|
self.assertIsInstance(result2, StreamingParseResult)
|
||||||
|
self.assertGreater(len(result2.calls), 0, "Should parse complete tool call")
|
||||||
|
|
||||||
|
# Test that next chunk with opening brace gets the separator prepended
|
||||||
|
next_chunk = '{"name": "get_weather", "parameters": {"location": "Paris"}}'
|
||||||
|
result_next = self.detector.parse_streaming_increment(next_chunk, self.tools)
|
||||||
|
self.assertIsInstance(result_next, StreamingParseResult)
|
||||||
|
self.assertGreater(
|
||||||
|
len(result_next.calls), 0, "Should parse subsequent tool call"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_three_tool_calls_separate_chunks_with_commas(self):
|
||||||
|
"""Test parsing 3 tool calls in separate chunks with commas at the end"""
|
||||||
|
# First tool call: 2 chunks
|
||||||
|
chunk1_1 = '[{"name": "get_weather", "parameters": '
|
||||||
|
result1_1 = self.detector.parse_streaming_increment(chunk1_1, self.tools)
|
||||||
|
chunk1_2 = '{"location": "Tokyo"}},'
|
||||||
|
result1_2 = self.detector.parse_streaming_increment(chunk1_2, self.tools)
|
||||||
|
self.assertIsInstance(result1_2, StreamingParseResult)
|
||||||
|
self.assertGreater(len(result1_2.calls), 0, "Should parse first tool call")
|
||||||
|
|
||||||
|
# Second tool call: 2 chunks
|
||||||
|
chunk2_1 = '{"name": "search", "parameters": '
|
||||||
|
result2_1 = self.detector.parse_streaming_increment(chunk2_1, self.tools)
|
||||||
|
chunk2_2 = '{"query": "restaurants"}},'
|
||||||
|
result2_2 = self.detector.parse_streaming_increment(chunk2_2, self.tools)
|
||||||
|
self.assertIsInstance(result2_2, StreamingParseResult)
|
||||||
|
self.assertGreater(len(result2_2.calls), 0, "Should parse second tool call")
|
||||||
|
|
||||||
|
# Third tool call: 2 chunks
|
||||||
|
chunk3_1 = '{"name": "get_weather", "parameters": '
|
||||||
|
result3_1 = self.detector.parse_streaming_increment(chunk3_1, self.tools)
|
||||||
|
chunk3_2 = '{"location": "Paris"}}]'
|
||||||
|
result3_2 = self.detector.parse_streaming_increment(chunk3_2, self.tools)
|
||||||
|
self.assertIsInstance(result3_2, StreamingParseResult)
|
||||||
|
self.assertGreater(len(result3_2.calls), 0, "Should parse third tool call")
|
||||||
|
# Verify all tool calls were parsed correctly
|
||||||
|
total_calls = len(result1_2.calls) + len(result2_2.calls) + len(result3_2.calls)
|
||||||
|
self.assertEqual(total_calls, 3, "Should have parsed exactly 3 tool calls")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user