Use jsonschema to constrain required or specific tool choice (#10550)

Tejesh Anand
2025-09-27 10:18:50 -07:00
committed by GitHub
parent 9c339d6b47
commit 8cc27fdc46
12 changed files with 1558 additions and 50 deletions
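In short: when a chat completion request sets tool_choice="required" or names a specific function, the server now constrains decoding with a JSON schema built from the tool definitions instead of an EBNF grammar. A minimal sketch of a request that exercises the new path (server URL and model name are placeholders, not part of this diff):

```python
# Sketch: calling an sglang OpenAI-compatible server with tool_choice="required".
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",  # illustrative tool, not from this diff
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

resp = client.chat.completions.create(
    model="placeholder-model",
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=tools,
    tool_choice="required",  # decoding is now schema-constrained server-side
)
print(resp.choices[0].message.tool_calls)
```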

View File: python/sglang/srt/entrypoints/openai/protocol.py

@@ -16,7 +16,7 @@
 import time
 import uuid
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, TypeAlias, Union
+from typing import Any, Dict, List, NamedTuple, Optional, TypeAlias, Union
 
 from openai.types.responses import (
     ResponseFunctionToolCall,
@@ -392,7 +392,7 @@ class Function(BaseModel):
     """Function descriptions."""
 
     description: Optional[str] = Field(default=None, examples=[None])
-    name: Optional[str] = None
+    name: str
     parameters: Optional[object] = None
     strict: bool = False
@@ -943,6 +943,16 @@ class MessageProcessingResult:
     tool_call_constraint: Optional[Any] = None
 
 
+class ToolCallProcessingResult(NamedTuple):
+    """Result of processing tool calls in a response."""
+
+    tool_calls: Optional[
+        List[Any]
+    ]  # List of ToolCall objects or None if parsing failed
+    remaining_text: str  # Text remaining after parsing tool calls
+    finish_reason: Dict[str, Any]  # Updated finish reason dictionary
+
+
 class ResponseReasoningTextContent(BaseModel):
     text: str
     type: Literal["reasoning_text"] = "reasoning_text"

View File: python/sglang/srt/entrypoints/openai/serving_base.py

@@ -62,6 +62,12 @@ class OpenAIServingBase(ABC):
             return self.create_error_response(
                 message=e.detail, err_type=str(e.status_code), status_code=e.status_code
             )
+        except ValueError as e:
+            return self.create_error_response(
+                message=str(e),
+                err_type="BadRequest",
+                status_code=400,
+            )
         except Exception as e:
             logger.exception(f"Error in request: {e}")
             return self.create_error_response(
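The new except ValueError branch pairs with helpers added later in this PR (e.g. _get_tool_schema_defs raises ValueError on conflicting $defs), so those failures now surface as HTTP 400 instead of an unhandled 500. A hedged end-to-end check (URL, model, and error shape are assumptions):

```python
# Two tools declaring conflicting $defs should now fail fast with 400.
import requests

def make_tool(name: str, defs: dict) -> dict:
    return {
        "type": "function",
        "function": {
            "name": name,
            "parameters": {"type": "object", "$defs": defs, "properties": {}},
        },
    }

payload = {
    "model": "placeholder-model",
    "messages": [{"role": "user", "content": "hi"}],
    "tools": [
        make_tool("a", {"Loc": {"type": "string"}}),
        make_tool("b", {"Loc": {"type": "integer"}}),  # conflicts with "a"
    ],
    "tool_choice": "required",
}
r = requests.post("http://localhost:30000/v1/chat/completions", json=payload)
print(r.status_code)  # expected: 400, err_type "BadRequest"
```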

View File: python/sglang/srt/entrypoints/openai/serving_chat.py

@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Uni
 
 from fastapi import Request
 from fastapi.responses import ORJSONResponse, StreamingResponse
+from jsonschema import Draft202012Validator, SchemaError
 
 from sglang.srt.entrypoints.openai.protocol import (
     ChatCompletionRequest,
@@ -25,6 +26,8 @@ from sglang.srt.entrypoints.openai.protocol import (
     LogProbs,
     MessageProcessingResult,
     ToolCall,
+    ToolCallProcessingResult,
+    ToolChoice,
     TopLogprob,
 )
 from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
@@ -35,6 +38,8 @@ from sglang.srt.entrypoints.openai.utils import (
 )
 from sglang.srt.function_call.core_types import ToolCallItem
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
+from sglang.srt.function_call.json_array_parser import JsonArrayParser
+from sglang.srt.function_call.utils import get_json_schema_constraint
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.parser.conversation import generate_chat_conv
 from sglang.srt.parser.jinja_template_utils import process_content_for_template_format
@@ -75,6 +80,23 @@ class OpenAIServingChat(OpenAIServingBase):
         ):
             return "Tools cannot be empty if tool choice is set to required."
 
+        if request.tool_choice is not None and not isinstance(request.tool_choice, str):
+            if not request.tools:
+                return "Tools cannot be empty if tool choice is set to a specific tool."
+            tool_name = request.tool_choice.function.name
+            tool_exists = any(tool.function.name == tool_name for tool in request.tools)
+            if not tool_exists:
+                return f"Tool '{tool_name}' not found in tools list."
+
+        # Validate tool definitions
+        for i, tool in enumerate(request.tools or []):
+            if tool.function.parameters is None:
+                continue
+            try:
+                Draft202012Validator.check_schema(tool.function.parameters)
+            except SchemaError as e:
+                return f"Tool {i} function has invalid 'parameters' schema: {str(e)}"
+
         max_output_tokens = request.max_completion_tokens or request.max_tokens
         server_context_length = self.tokenizer_manager.server_args.context_length
         if (
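Draft202012Validator.check_schema validates each tool's parameters block against the JSON Schema 2020-12 metaschema, rejecting malformed schemas before any generation happens:

```python
# jsonschema's check_schema raises SchemaError for schemas that violate
# the 2020-12 metaschema, e.g. an unknown "type" value.
from jsonschema import Draft202012Validator, SchemaError

try:
    Draft202012Validator.check_schema({"type": "not-a-real-type"})
except SchemaError as e:
    print(f"invalid 'parameters' schema: {e}")
```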
@@ -190,6 +212,14 @@ class OpenAIServingChat(OpenAIServingBase):
                 tool_call_constraint = parser.get_structure_constraint(
                     request.tool_choice
                 )
+                # Handle JSON schema constraint directly for required or named tool choice
+                if request.tool_choice == "required" or isinstance(
+                    request.tool_choice, ToolChoice
+                ):
+                    json_schema = get_json_schema_constraint(
+                        request.tools, request.tool_choice
+                    )
+                    tool_call_constraint = ("json_schema", json_schema)
 
         # Use chat template
         if self.template_manager.chat_template_name is None:
@@ -437,6 +467,10 @@ class OpenAIServingChat(OpenAIServingBase):
                 sampling_params[constraint_type] = convert_json_schema_to_str(
                     constraint_value.model_dump(by_alias=True)
                 )
+            elif constraint_type == "json_schema":
+                sampling_params[constraint_type] = convert_json_schema_to_str(
+                    constraint_value
+                )
             else:
                 sampling_params[constraint_type] = constraint_value
         return sampling_params
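convert_json_schema_to_str is sglang's existing helper; the point of the new branch is that the constraint dict must reach the scheduler as a string. A rough sketch, assuming the helper is equivalent to json.dumps for plain dict input:

```python
import json

# Hypothetical constraint value in the shape produced by get_json_schema_constraint.
json_schema = {"type": "array", "minItems": 1, "items": {"type": "object"}}
sampling_params = {"json_schema": json.dumps(json_schema)}
print(sampling_params["json_schema"])
```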
@@ -752,7 +786,11 @@ class OpenAIServingChat(OpenAIServingBase):
         ):
             history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
             tool_calls, text, finish_reason = self._process_tool_calls(
-                text, request.tools, finish_reason, history_tool_calls_cnt
+                text,
+                request.tools,
+                finish_reason,
+                request.tool_choice,
+                history_tool_calls_cnt,
             )
 
         choice_data = ChatCompletionResponseChoice(
@@ -867,9 +905,51 @@ class OpenAIServingChat(OpenAIServingBase):
         text: str,
         tools: List[Any],
         finish_reason: Dict[str, Any],
+        tool_choice: Optional[Union[str, ToolChoice]] = None,
         history_tool_calls_cnt: int = 0,
-    ) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
+    ) -> ToolCallProcessingResult:
         """Process tool calls in the response"""
+        # Handle required or named tool choice
+        if tool_choice == "required" or (
+            isinstance(tool_choice, ToolChoice) and tool_choice.type == "function"
+        ):
+            # Set finish reason to tool_calls since we're processing tool calls
+            if finish_reason["type"] == "stop":
+                finish_reason["type"] = "tool_calls"
+                finish_reason["matched"] = None
+            try:
+                # For required tool choice, we expect a JSON array of tool calls
+                tool_call_data = json.loads(text)
+                tool_calls = []
+                for i, tool in enumerate(tool_call_data):
+                    # Create a ToolCallItem from the JSON data
+                    call_info = ToolCallItem(
+                        tool_index=i,  # Use the loop index as tool_index
+                        name=tool["name"],
+                        parameters=json.dumps(tool["parameters"], ensure_ascii=False),
+                    )
+                    tool_id = self._process_tool_call_id(
+                        call_info, history_tool_calls_cnt
+                    )
+                    tool_calls.append(
+                        ToolCall(
+                            id=tool_id,
+                            index=i,
+                            function=FunctionResponse(
+                                name=tool["name"],
+                                arguments=json.dumps(
+                                    tool["parameters"], ensure_ascii=False
+                                ),
+                            ),
+                        )
+                    )
+                return ToolCallProcessingResult(tool_calls, "", finish_reason)
+            except json.JSONDecodeError as e:
+                logger.error(f"Tool call parsing error: {e}")
+                return ToolCallProcessingResult(None, text, finish_reason)
+
+        # Use parser since output is not constrained by JSON schema
         parser = FunctionCallParser(tools, self.tool_call_parser)
         if parser.has_tool_call(text):
             if finish_reason["type"] == "stop":
@@ -891,13 +971,13 @@ class OpenAIServingChat(OpenAIServingBase):
                         ),
                     )
                 )
-            return tool_calls, text, finish_reason
+            return ToolCallProcessingResult(tool_calls, text, finish_reason)
         except Exception as e:
             logger.error(f"Tool call parsing error: {e}")
             # Return error but don't fail the whole request
-            return None, text, finish_reason
+            return ToolCallProcessingResult(None, text, finish_reason)
 
-        return None, text, finish_reason
+        return ToolCallProcessingResult(None, text, finish_reason)
 
     def _process_streaming_logprobs(
         self, content: Dict[str, Any], n_prev_token: int
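With the schema constraint active, the entire completion is a JSON array of {"name", "parameters"} objects, so a plain json.loads replaces model-specific parsing. A minimal illustration of the loop above (tool call id handling omitted):

```python
import json

text = '[{"name": "get_weather", "parameters": {"city": "Paris"}}]'
tool_calls = [
    {
        "index": i,
        "name": tool["name"],
        "arguments": json.dumps(tool["parameters"], ensure_ascii=False),
    }
    for i, tool in enumerate(json.loads(text))
]
print(tool_calls)
```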
@@ -990,13 +1070,25 @@
     ):
         """Process tool calls in streaming response"""
         if index not in parser_dict:
-            parser_dict[index] = FunctionCallParser(
-                tools=request.tools,
-                tool_call_parser=self.tool_call_parser,
-            )
+            # Use JSON detector directly for required or named tool choice
+            if request.tool_choice == "required" or isinstance(
+                request.tool_choice, ToolChoice
+            ):
+                parser_dict[index] = JsonArrayParser()
+            else:
+                parser_dict[index] = FunctionCallParser(
+                    tools=request.tools,
+                    tool_call_parser=self.tool_call_parser,
+                )
         parser = parser_dict[index]
 
-        normal_text, calls = parser.parse_stream_chunk(delta)
+        # Handle both FunctionCallParser and JsonArrayParser
+        if isinstance(parser, JsonArrayParser):
+            result = parser.parse_streaming_increment(delta, request.tools)
+            normal_text, calls = result.normal_text, result.calls
+        else:
+            normal_text, calls = parser.parse_stream_chunk(delta)
 
         # Yield normal text
         if normal_text:
@@ -1055,7 +1147,7 @@
 
     def _check_for_unstreamed_tool_args(
         self,
-        parser: FunctionCallParser,
+        parser: Union[FunctionCallParser, JsonArrayParser],
         content: Dict[str, Any],
         request: ChatCompletionRequest,
         index: int,
@@ -1065,30 +1157,31 @@
         when generation finishes. This ensures tool calls are properly completed
         even if the model generates the final arguments in the last chunk.
        """
-        # Only check if we have tool calls and the parser has tracked data
+        # Get the detector - either from FunctionCallParser or directly if json detector
+        detector = parser.detector if hasattr(parser, "detector") else parser
+
+        # Only check if we have tool calls and the detector has tracked data
         if (
-            not hasattr(parser.detector, "prev_tool_call_arr")
-            or not parser.detector.prev_tool_call_arr
+            not hasattr(detector, "prev_tool_call_arr")
+            or not detector.prev_tool_call_arr
         ):
             return None
 
         if (
-            not hasattr(parser.detector, "streamed_args_for_tool")
-            or not parser.detector.streamed_args_for_tool
+            not hasattr(detector, "streamed_args_for_tool")
+            or not detector.streamed_args_for_tool
         ):
             return None
 
         # Get the last tool call that was being processed
-        tool_index = len(parser.detector.prev_tool_call_arr) - 1
-        if tool_index < 0 or tool_index >= len(parser.detector.streamed_args_for_tool):
+        tool_index = len(detector.prev_tool_call_arr) - 1
+        if tool_index < 0 or tool_index >= len(detector.streamed_args_for_tool):
             return None
 
         # Get expected vs actual arguments
-        expected_args = parser.detector.prev_tool_call_arr[tool_index].get(
-            "arguments", {}
-        )
+        expected_args = detector.prev_tool_call_arr[tool_index].get("arguments", {})
         expected_call = json.dumps(expected_args, ensure_ascii=False)
-        actual_call = parser.detector.streamed_args_for_tool[tool_index]
+        actual_call = detector.streamed_args_for_tool[tool_index]
 
         # Check if there are remaining arguments to send
         remaining_call = (
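The idea behind the unstreamed-argument check: compare the fully parsed arguments against what has already been streamed for the last tool call, and emit only the missing suffix. A simplified sketch of that comparison (the real code derives both strings from the detector state shown above):

```python
import json

# Full arguments the detector parsed for the last tool call.
expected_call = json.dumps({"city": "Paris", "unit": "celsius"}, ensure_ascii=False)
# What has actually been streamed to the client so far.
actual_call = '{"city": "Paris"'

# Send only the portion the client has not seen yet.
remaining_call = (
    expected_call[len(actual_call):]
    if expected_call.startswith(actual_call)
    else expected_call
)
print(remaining_call)  # -> ', "unit": "celsius"}'
```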

View File: python/sglang/srt/function_call/function_call_parser.py

@@ -20,6 +20,7 @@ from sglang.srt.function_call.pythonic_detector import PythonicDetector
 from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
 from sglang.srt.function_call.qwen25_detector import Qwen25Detector
 from sglang.srt.function_call.step3_detector import Step3Detector
+from sglang.srt.function_call.utils import get_json_schema_constraint
 
 logger = logging.getLogger(__name__)
 
@@ -178,8 +179,8 @@ class FunctionCallParser:
             strict_tag = self.get_structure_tag()
             return ("structural_tag", strict_tag)
         elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
-            ebnf = self.get_ebnf(tool_choice)
-            return ("ebnf", ebnf) if ebnf is not None else None
+            json_schema = get_json_schema_constraint(self.tools, tool_choice)
+            return ("json_schema", json_schema)
 
     def get_ebnf(
         self, tool_choice: Union[ToolChoice, Literal["required"]]

View File: python/sglang/srt/function_call/json_array_parser.py

@@ -0,0 +1,63 @@
+import json
+import re
+from typing import List
+
+from sglang.srt.entrypoints.openai.protocol import Tool
+from sglang.srt.function_call.base_format_detector import BaseFormatDetector
+from sglang.srt.function_call.core_types import StreamingParseResult
+
+
+class JsonArrayParser(BaseFormatDetector):
+    """
+    Parser for JSON array tool calls when JSON schema constraints are active.
+
+    This parser is used when tool_choice="required" or a specific tool is named,
+    bypassing model-specific parsers in favor of direct JSON array parsing.
+    """
+
+    def __init__(self):
+        super().__init__()
+        # Configure for JSON array parsing
+        self.bot_token = "["
+        self.eot_token = "]"
+        self.tool_call_separator = ","
+
+    def has_tool_call(self, text: str) -> bool:
+        """
+        Check if the given text contains a JSON tool call (array or single object).
+        """
+        return "[" in text or "{" in text
+
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+        """
+        Parse JSON tool calls using the base class implementation.
+        """
+        raise NotImplementedError(
+            "Detect and parse not supported for JSON schema constraints."
+        )
+
+    def build_ebnf(self, tools: List[Tool]) -> str:
+        """
+        Build an EBNF grammar for constrained generation.
+
+        This is not used for JSON schema constraints as they are handled
+        by the constraint backends directly.
+        """
+        raise NotImplementedError(
+            "EBNF generation is not supported for JSON schema constraints."
+        )
+
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Tool]
+    ) -> StreamingParseResult:
+        """
+        Streaming incremental parsing with tool validation.
+        """
+        return super().parse_streaming_increment(new_text, tools)
+
+    def structure_info(self) -> callable:
+        """
+        Return a function that creates StructureInfo for constrained generation.
+
+        This is not used for JSON schema constraints as they are handled
+        by the constraint backends directly.
+        """
+        raise NotImplementedError("structure_info not used for JSON schema constraints")
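Since JsonArrayParser reuses BaseFormatDetector's streaming machinery with "[", "]", and "," as framing tokens, it can be driven chunk by chunk. A hedged usage sketch (the tool definition and deltas are illustrative):

```python
from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.function_call.json_array_parser import JsonArrayParser

tools = [
    Tool(
        type="function",
        function=Function(
            name="get_weather",
            parameters={"type": "object", "properties": {"city": {"type": "string"}}},
        ),
    )
]
parser = JsonArrayParser()
for delta in ['[{"name": "get_weather", ', '"parameters": {"city": "Paris"}}]']:
    result = parser.parse_streaming_increment(delta, tools)
    print(result.normal_text, result.calls)
```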

View File: python/sglang/srt/function_call/utils.py

@@ -1,10 +1,13 @@
 import json
 from json import JSONDecodeError, JSONDecoder
-from typing import Any, Tuple
+from json.decoder import WHITESPACE
+from typing import Any, List, Literal, Optional, Tuple, Union
 
 import partial_json_parser
 from partial_json_parser.core.options import Allow
 
+from sglang.srt.entrypoints.openai.protocol import Tool, ToolChoice
+
 
 def _find_common_prefix(s1: str, s2: str) -> str:
     prefix = ""
@@ -37,10 +40,12 @@ def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
     """
     try:
         return (partial_json_parser.loads(input_str, flags), len(input_str))
-    except JSONDecodeError as e:
-        if "Extra data" in e.msg:
-            dec = JSONDecoder()
-            return dec.raw_decode(input_str)
+    except (JSONDecodeError, IndexError) as e:
+        msg = getattr(e, "msg", str(e))
+        if "Extra data" in msg or "pop from empty list" in msg:
+            start = WHITESPACE.match(input_str, 0).end()
+            obj, end = JSONDecoder().raw_decode(input_str, start)
+            return obj, end
         raise
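Context for the widened except: partial_json_parser can raise IndexError ("pop from empty list") on some truncated inputs, and raw_decode now skips leading whitespace before decoding. A quick demonstration of the library behavior this helper wraps (exact outputs depend on the Allow flags):

```python
import partial_json_parser
from partial_json_parser.core.options import Allow

# Complete what can safely be completed; don't invent the end of a dangling string.
flags = Allow.ALL & ~Allow.STR
print(partial_json_parser.loads('{"city": "Paris", "unit": "ce', flags))
print(partial_json_parser.loads('{"city": "Paris"', flags))
```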
@@ -50,3 +55,89 @@ def _is_complete_json(input_str: str) -> bool:
         return True
     except JSONDecodeError:
         return False
+
+
+def _get_tool_schema_defs(tools: List[Tool]) -> dict:
+    """
+    Get consolidated $defs from all tools, validating for conflicts.
+
+    Args:
+        tools: List of tools to process
+
+    Returns:
+        Dictionary of consolidated $defs from all tools
+
+    Raises:
+        ValueError: If conflicting $defs are found
+    """
+    all_defs = {}
+    for tool in tools:
+        if tool.function.parameters is None:
+            continue
+        defs = tool.function.parameters.get("$defs", {})
+        for def_name, def_schema in defs.items():
+            if def_name in all_defs and all_defs[def_name] != def_schema:
+                raise ValueError(
+                    f"Tool definition '{def_name}' has "
+                    "multiple schemas, which is not "
+                    "supported."
+                )
+            else:
+                all_defs[def_name] = def_schema
+    return all_defs
+
+
+def _get_tool_schema(tool: Tool) -> dict:
+    return {
+        "properties": {
+            "name": {"type": "string", "enum": [tool.function.name]},
+            "parameters": (
+                tool.function.parameters
+                if tool.function.parameters
+                else {"type": "object", "properties": {}}
+            ),
+        },
+        "required": ["name", "parameters"],
+    }
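For a hypothetical get_weather tool, _get_tool_schema pins name with a single-value enum and embeds the user's parameters schema verbatim:

```python
# Illustrative return value of _get_tool_schema (tool name is hypothetical).
{
    "properties": {
        "name": {"type": "string", "enum": ["get_weather"]},
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
    "required": ["name", "parameters"],
}
```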
+
+
+def get_json_schema_constraint(
+    tools: List[Tool], tool_choice: Union[ToolChoice, Literal["required"]]
+) -> Optional[dict]:
+    """
+    Get the JSON schema constraint for the specified tool choice.
+
+    Args:
+        tool_choice: The tool choice specification
+
+    Returns:
+        JSON schema dict, or None if no valid tools found
+    """
+    if isinstance(tool_choice, ToolChoice):
+        # For specific function choice, return the user's parameters schema directly
+        fn_name = tool_choice.function.name
+        for tool in tools:
+            if tool.function.name == fn_name:
+                return {
+                    "type": "array",
+                    "minItems": 1,
+                    "maxItems": 1,
+                    "items": _get_tool_schema(tool),
+                }
+        return None
+    elif tool_choice == "required":
+        json_schema = {
+            "type": "array",
+            "minItems": 1,
+            "items": {
+                "type": "object",
+                "anyOf": [_get_tool_schema(tool) for tool in tools],
+            },
+        }
+        json_schema_defs = _get_tool_schema_defs(tools)
+        if json_schema_defs:
+            json_schema["$defs"] = json_schema_defs
+        return json_schema
+    return None
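Putting it together: tool_choice="required" yields a non-empty array whose items must match one of the tools via anyOf, while a named tool yields an array of exactly one item. A hedged sketch (tool definitions illustrative):

```python
from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.function_call.utils import get_json_schema_constraint

tools = [
    Tool(
        type="function",
        function=Function(
            name="get_weather",
            parameters={"type": "object", "properties": {"city": {"type": "string"}}},
        ),
    ),
    Tool(type="function", function=Function(name="noop", parameters=None)),
]

schema = get_json_schema_constraint(tools, "required")
# -> {"type": "array", "minItems": 1,
#     "items": {"type": "object", "anyOf": [<get_weather schema>, <noop schema>]}}
print(schema)
```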