Use jsonschema to constrain required or specific tool choice (#10550)
This commit is contained in:
@@ -16,7 +16,7 @@
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, TypeAlias, Union
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, TypeAlias, Union
|
||||
|
||||
from openai.types.responses import (
|
||||
ResponseFunctionToolCall,
|
||||
@@ -392,7 +392,7 @@ class Function(BaseModel):
|
||||
"""Function descriptions."""
|
||||
|
||||
description: Optional[str] = Field(default=None, examples=[None])
|
||||
name: Optional[str] = None
|
||||
name: str
|
||||
parameters: Optional[object] = None
|
||||
strict: bool = False
|
||||
|
||||
@@ -943,6 +943,16 @@ class MessageProcessingResult:
|
||||
tool_call_constraint: Optional[Any] = None
|
||||
|
||||
|
||||
class ToolCallProcessingResult(NamedTuple):
|
||||
"""Result of processing tool calls in a response."""
|
||||
|
||||
tool_calls: Optional[
|
||||
List[Any]
|
||||
] # List of ToolCall objects or None if parsing failed
|
||||
remaining_text: str # Text remaining after parsing tool calls
|
||||
finish_reason: Dict[str, Any] # Updated finish reason dictionary
|
||||
|
||||
|
||||
class ResponseReasoningTextContent(BaseModel):
|
||||
text: str
|
||||
type: Literal["reasoning_text"] = "reasoning_text"
|
||||
|
||||
@@ -62,6 +62,12 @@ class OpenAIServingBase(ABC):
|
||||
return self.create_error_response(
|
||||
message=e.detail, err_type=str(e.status_code), status_code=e.status_code
|
||||
)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(
|
||||
message=str(e),
|
||||
err_type="BadRequest",
|
||||
status_code=400,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in request: {e}")
|
||||
return self.create_error_response(
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Uni
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import ORJSONResponse, StreamingResponse
|
||||
from jsonschema import Draft202012Validator, SchemaError
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
@@ -25,6 +26,8 @@ from sglang.srt.entrypoints.openai.protocol import (
|
||||
LogProbs,
|
||||
MessageProcessingResult,
|
||||
ToolCall,
|
||||
ToolCallProcessingResult,
|
||||
ToolChoice,
|
||||
TopLogprob,
|
||||
)
|
||||
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
|
||||
@@ -35,6 +38,8 @@ from sglang.srt.entrypoints.openai.utils import (
|
||||
)
|
||||
from sglang.srt.function_call.core_types import ToolCallItem
|
||||
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||
from sglang.srt.function_call.json_array_parser import JsonArrayParser
|
||||
from sglang.srt.function_call.utils import get_json_schema_constraint
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
from sglang.srt.parser.conversation import generate_chat_conv
|
||||
from sglang.srt.parser.jinja_template_utils import process_content_for_template_format
|
||||
@@ -75,6 +80,23 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
):
|
||||
return "Tools cannot be empty if tool choice is set to required."
|
||||
|
||||
if request.tool_choice is not None and not isinstance(request.tool_choice, str):
|
||||
if not request.tools:
|
||||
return "Tools cannot be empty if tool choice is set to a specific tool."
|
||||
tool_name = request.tool_choice.function.name
|
||||
tool_exists = any(tool.function.name == tool_name for tool in request.tools)
|
||||
if not tool_exists:
|
||||
return f"Tool '{tool_name}' not found in tools list."
|
||||
|
||||
# Validate tool definitions
|
||||
for i, tool in enumerate(request.tools or []):
|
||||
if tool.function.parameters is None:
|
||||
continue
|
||||
try:
|
||||
Draft202012Validator.check_schema(tool.function.parameters)
|
||||
except SchemaError as e:
|
||||
return f"Tool {i} function has invalid 'parameters' schema: {str(e)}"
|
||||
|
||||
max_output_tokens = request.max_completion_tokens or request.max_tokens
|
||||
server_context_length = self.tokenizer_manager.server_args.context_length
|
||||
if (
|
||||
@@ -190,6 +212,14 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
tool_call_constraint = parser.get_structure_constraint(
|
||||
request.tool_choice
|
||||
)
|
||||
# Handle JSON schema constraint directly for required or named tool choice
|
||||
if request.tool_choice == "required" or isinstance(
|
||||
request.tool_choice, ToolChoice
|
||||
):
|
||||
json_schema = get_json_schema_constraint(
|
||||
request.tools, request.tool_choice
|
||||
)
|
||||
tool_call_constraint = ("json_schema", json_schema)
|
||||
|
||||
# Use chat template
|
||||
if self.template_manager.chat_template_name is None:
|
||||
@@ -437,6 +467,10 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
sampling_params[constraint_type] = convert_json_schema_to_str(
|
||||
constraint_value.model_dump(by_alias=True)
|
||||
)
|
||||
elif constraint_type == "json_schema":
|
||||
sampling_params[constraint_type] = convert_json_schema_to_str(
|
||||
constraint_value
|
||||
)
|
||||
else:
|
||||
sampling_params[constraint_type] = constraint_value
|
||||
return sampling_params
|
||||
@@ -752,7 +786,11 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
):
|
||||
history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
|
||||
tool_calls, text, finish_reason = self._process_tool_calls(
|
||||
text, request.tools, finish_reason, history_tool_calls_cnt
|
||||
text,
|
||||
request.tools,
|
||||
finish_reason,
|
||||
request.tool_choice,
|
||||
history_tool_calls_cnt,
|
||||
)
|
||||
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
@@ -867,9 +905,51 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
text: str,
|
||||
tools: List[Any],
|
||||
finish_reason: Dict[str, Any],
|
||||
tool_choice: Optional[Union[str, ToolChoice]] = None,
|
||||
history_tool_calls_cnt: int = 0,
|
||||
) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
|
||||
) -> ToolCallProcessingResult:
|
||||
"""Process tool calls in the response"""
|
||||
|
||||
# Handle required or named tool choice
|
||||
if tool_choice == "required" or (
|
||||
isinstance(tool_choice, ToolChoice) and tool_choice.type == "function"
|
||||
):
|
||||
# Set finish reason to tool_calls since we're processing tool calls
|
||||
if finish_reason["type"] == "stop":
|
||||
finish_reason["type"] = "tool_calls"
|
||||
finish_reason["matched"] = None
|
||||
try:
|
||||
# For required tool choice, we expect a JSON array of tool calls
|
||||
tool_call_data = json.loads(text)
|
||||
tool_calls = []
|
||||
for i, tool in enumerate(tool_call_data):
|
||||
# Create a ToolCallItem from the JSON data
|
||||
call_info = ToolCallItem(
|
||||
tool_index=i, # Use the loop index as tool_index
|
||||
name=tool["name"],
|
||||
parameters=json.dumps(tool["parameters"], ensure_ascii=False),
|
||||
)
|
||||
tool_id = self._process_tool_call_id(
|
||||
call_info, history_tool_calls_cnt
|
||||
)
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=tool_id,
|
||||
index=i,
|
||||
function=FunctionResponse(
|
||||
name=tool["name"],
|
||||
arguments=json.dumps(
|
||||
tool["parameters"], ensure_ascii=False
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
return ToolCallProcessingResult(tool_calls, "", finish_reason)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Tool call parsing error: {e}")
|
||||
return ToolCallProcessingResult(None, text, finish_reason)
|
||||
|
||||
# Use parser since output is not constrained by JSON schema
|
||||
parser = FunctionCallParser(tools, self.tool_call_parser)
|
||||
if parser.has_tool_call(text):
|
||||
if finish_reason["type"] == "stop":
|
||||
@@ -891,13 +971,13 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
),
|
||||
)
|
||||
)
|
||||
return tool_calls, text, finish_reason
|
||||
return ToolCallProcessingResult(tool_calls, text, finish_reason)
|
||||
except Exception as e:
|
||||
logger.error(f"Tool call parsing error: {e}")
|
||||
# Return error but don't fail the whole request
|
||||
return None, text, finish_reason
|
||||
return ToolCallProcessingResult(None, text, finish_reason)
|
||||
|
||||
return None, text, finish_reason
|
||||
return ToolCallProcessingResult(None, text, finish_reason)
|
||||
|
||||
def _process_streaming_logprobs(
|
||||
self, content: Dict[str, Any], n_prev_token: int
|
||||
@@ -990,13 +1070,25 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
):
|
||||
"""Process tool calls in streaming response"""
|
||||
if index not in parser_dict:
|
||||
parser_dict[index] = FunctionCallParser(
|
||||
tools=request.tools,
|
||||
tool_call_parser=self.tool_call_parser,
|
||||
)
|
||||
# Use JSON detector directly for required or named tool choice
|
||||
if request.tool_choice == "required" or isinstance(
|
||||
request.tool_choice, ToolChoice
|
||||
):
|
||||
parser_dict[index] = JsonArrayParser()
|
||||
else:
|
||||
parser_dict[index] = FunctionCallParser(
|
||||
tools=request.tools,
|
||||
tool_call_parser=self.tool_call_parser,
|
||||
)
|
||||
|
||||
parser = parser_dict[index]
|
||||
|
||||
normal_text, calls = parser.parse_stream_chunk(delta)
|
||||
# Handle both FunctionCallParser and JsonArrayParser
|
||||
if isinstance(parser, JsonArrayParser):
|
||||
result = parser.parse_streaming_increment(delta, request.tools)
|
||||
normal_text, calls = result.normal_text, result.calls
|
||||
else:
|
||||
normal_text, calls = parser.parse_stream_chunk(delta)
|
||||
|
||||
# Yield normal text
|
||||
if normal_text:
|
||||
@@ -1055,7 +1147,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
|
||||
def _check_for_unstreamed_tool_args(
|
||||
self,
|
||||
parser: FunctionCallParser,
|
||||
parser: Union[FunctionCallParser, JsonArrayParser],
|
||||
content: Dict[str, Any],
|
||||
request: ChatCompletionRequest,
|
||||
index: int,
|
||||
@@ -1065,30 +1157,31 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
when generation finishes. This ensures tool calls are properly completed
|
||||
even if the model generates the final arguments in the last chunk.
|
||||
"""
|
||||
# Only check if we have tool calls and the parser has tracked data
|
||||
# Get the detector - either from FunctionCallParser or directly if json detector
|
||||
detector = parser.detector if hasattr(parser, "detector") else parser
|
||||
|
||||
# Only check if we have tool calls and the detector has tracked data
|
||||
if (
|
||||
not hasattr(parser.detector, "prev_tool_call_arr")
|
||||
or not parser.detector.prev_tool_call_arr
|
||||
not hasattr(detector, "prev_tool_call_arr")
|
||||
or not detector.prev_tool_call_arr
|
||||
):
|
||||
return None
|
||||
|
||||
if (
|
||||
not hasattr(parser.detector, "streamed_args_for_tool")
|
||||
or not parser.detector.streamed_args_for_tool
|
||||
not hasattr(detector, "streamed_args_for_tool")
|
||||
or not detector.streamed_args_for_tool
|
||||
):
|
||||
return None
|
||||
|
||||
# Get the last tool call that was being processed
|
||||
tool_index = len(parser.detector.prev_tool_call_arr) - 1
|
||||
if tool_index < 0 or tool_index >= len(parser.detector.streamed_args_for_tool):
|
||||
tool_index = len(detector.prev_tool_call_arr) - 1
|
||||
if tool_index < 0 or tool_index >= len(detector.streamed_args_for_tool):
|
||||
return None
|
||||
|
||||
# Get expected vs actual arguments
|
||||
expected_args = parser.detector.prev_tool_call_arr[tool_index].get(
|
||||
"arguments", {}
|
||||
)
|
||||
expected_args = detector.prev_tool_call_arr[tool_index].get("arguments", {})
|
||||
expected_call = json.dumps(expected_args, ensure_ascii=False)
|
||||
actual_call = parser.detector.streamed_args_for_tool[tool_index]
|
||||
actual_call = detector.streamed_args_for_tool[tool_index]
|
||||
|
||||
# Check if there are remaining arguments to send
|
||||
remaining_call = (
|
||||
|
||||
@@ -20,6 +20,7 @@ from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
||||
from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
|
||||
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
||||
from sglang.srt.function_call.step3_detector import Step3Detector
|
||||
from sglang.srt.function_call.utils import get_json_schema_constraint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -178,8 +179,8 @@ class FunctionCallParser:
|
||||
strict_tag = self.get_structure_tag()
|
||||
return ("structural_tag", strict_tag)
|
||||
elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
|
||||
ebnf = self.get_ebnf(tool_choice)
|
||||
return ("ebnf", ebnf) if ebnf is not None else None
|
||||
json_schema = get_json_schema_constraint(self.tools, tool_choice)
|
||||
return ("json_schema", json_schema)
|
||||
|
||||
def get_ebnf(
|
||||
self, tool_choice: Union[ToolChoice, Literal["required"]]
|
||||
|
||||
63
python/sglang/srt/function_call/json_array_parser.py
Normal file
63
python/sglang/srt/function_call/json_array_parser.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import json
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.core_types import StreamingParseResult
|
||||
|
||||
|
||||
class JsonArrayParser(BaseFormatDetector):
|
||||
"""
|
||||
Parser for JSON array tool calls when JSON schema constraints are active.
|
||||
|
||||
This parser is used when tool_choice="required" or a specific tool is named,
|
||||
bypassing model-specific parsers in favor of direct JSON array parsing.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Configure for JSON array parsing
|
||||
self.bot_token = "["
|
||||
self.eot_token = "]"
|
||||
self.tool_call_separator = ","
|
||||
|
||||
def has_tool_call(self, text: str) -> bool:
|
||||
"""
|
||||
Check if the given text contains a JSON tool call (array or single object).
|
||||
"""
|
||||
return "[" in text or "{" in text
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||
"""
|
||||
Parse JSON tool calls using the base class implementation.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"Detect and parse not supported for JSON schema constraints."
|
||||
)
|
||||
|
||||
def build_ebnf(self, tools: List[Tool]) -> str:
|
||||
"""
|
||||
Build an EBNF grammar for constrained generation.
|
||||
This is not used for JSON schema constraints as they are handled
|
||||
by the constraint backends directly.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"EBNF generation is not supported for JSON schema constraints."
|
||||
)
|
||||
|
||||
def parse_streaming_increment(
|
||||
self, new_text: str, tools: List[Tool]
|
||||
) -> StreamingParseResult:
|
||||
"""
|
||||
Streaming incremental parsing with tool validation.
|
||||
"""
|
||||
return super().parse_streaming_increment(new_text, tools)
|
||||
|
||||
def structure_info(self) -> callable:
|
||||
"""
|
||||
Return a function that creates StructureInfo for constrained generation.
|
||||
This is not used for JSON schema constraints as they are handled
|
||||
by the constraint backends directly.
|
||||
"""
|
||||
raise NotImplementedError("structure_info not used for JSON schema constraints")
|
||||
@@ -1,10 +1,13 @@
|
||||
import json
|
||||
from json import JSONDecodeError, JSONDecoder
|
||||
from typing import Any, Tuple
|
||||
from json.decoder import WHITESPACE
|
||||
from typing import Any, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import Tool, ToolChoice
|
||||
|
||||
|
||||
def _find_common_prefix(s1: str, s2: str) -> str:
|
||||
prefix = ""
|
||||
@@ -37,10 +40,12 @@ def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
|
||||
"""
|
||||
try:
|
||||
return (partial_json_parser.loads(input_str, flags), len(input_str))
|
||||
except JSONDecodeError as e:
|
||||
if "Extra data" in e.msg:
|
||||
dec = JSONDecoder()
|
||||
return dec.raw_decode(input_str)
|
||||
except (JSONDecodeError, IndexError) as e:
|
||||
msg = getattr(e, "msg", str(e))
|
||||
if "Extra data" in msg or "pop from empty list" in msg:
|
||||
start = WHITESPACE.match(input_str, 0).end()
|
||||
obj, end = JSONDecoder().raw_decode(input_str, start)
|
||||
return obj, end
|
||||
raise
|
||||
|
||||
|
||||
@@ -50,3 +55,89 @@ def _is_complete_json(input_str: str) -> bool:
|
||||
return True
|
||||
except JSONDecodeError:
|
||||
return False
|
||||
|
||||
|
||||
def _get_tool_schema_defs(tools: List[Tool]) -> dict:
|
||||
"""
|
||||
Get consolidated $defs from all tools, validating for conflicts.
|
||||
|
||||
Args:
|
||||
tools: List of tools to process
|
||||
|
||||
Returns:
|
||||
Dictionary of consolidated $defs from all tools
|
||||
|
||||
Raises:
|
||||
ValueError: If conflicting $defs are found
|
||||
"""
|
||||
all_defs = {}
|
||||
for tool in tools:
|
||||
if tool.function.parameters is None:
|
||||
continue
|
||||
defs = tool.function.parameters.get("$defs", {})
|
||||
for def_name, def_schema in defs.items():
|
||||
if def_name in all_defs and all_defs[def_name] != def_schema:
|
||||
raise ValueError(
|
||||
f"Tool definition '{def_name}' has "
|
||||
"multiple schemas, which is not "
|
||||
"supported."
|
||||
)
|
||||
else:
|
||||
all_defs[def_name] = def_schema
|
||||
return all_defs
|
||||
|
||||
|
||||
def _get_tool_schema(tool: Tool) -> dict:
|
||||
return {
|
||||
"properties": {
|
||||
"name": {"type": "string", "enum": [tool.function.name]},
|
||||
"parameters": (
|
||||
tool.function.parameters
|
||||
if tool.function.parameters
|
||||
else {"type": "object", "properties": {}}
|
||||
),
|
||||
},
|
||||
"required": ["name", "parameters"],
|
||||
}
|
||||
|
||||
|
||||
def get_json_schema_constraint(
|
||||
tools: List[Tool], tool_choice: Union[ToolChoice, Literal["required"]]
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the JSON schema constraint for the specified tool choice.
|
||||
|
||||
Args:
|
||||
tool_choice: The tool choice specification
|
||||
|
||||
Returns:
|
||||
JSON schema dict, or None if no valid tools found
|
||||
"""
|
||||
|
||||
if isinstance(tool_choice, ToolChoice):
|
||||
# For specific function choice, return the user's parameters schema directly
|
||||
fn_name = tool_choice.function.name
|
||||
for tool in tools:
|
||||
if tool.function.name == fn_name:
|
||||
return {
|
||||
"type": "array",
|
||||
"minItems": 1,
|
||||
"maxItems": 1,
|
||||
"items": _get_tool_schema(tool),
|
||||
}
|
||||
return None
|
||||
elif tool_choice == "required":
|
||||
json_schema = {
|
||||
"type": "array",
|
||||
"minItems": 1,
|
||||
"items": {
|
||||
"type": "object",
|
||||
"anyOf": [_get_tool_schema(tool) for tool in tools],
|
||||
},
|
||||
}
|
||||
json_schema_defs = _get_tool_schema_defs(tools)
|
||||
if json_schema_defs:
|
||||
json_schema["$defs"] = json_schema_defs
|
||||
return json_schema
|
||||
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user