From ac3fae8445da9791ce95db9fc28db01976b68e4c Mon Sep 17 00:00:00 2001
From: DarkSharpness <76582120+DarkSharpness@users.noreply.github.com>
Date: Tue, 25 Mar 2025 14:15:25 +0900
Subject: [PATCH] [Feature] Support "strict" in function calling (#4310)
---
python/sglang/srt/function_call_parser.py | 155 +++++++++++++++-------
python/sglang/srt/openai_api/adapter.py | 25 +++-
python/sglang/srt/openai_api/protocol.py | 1 +
test/srt/test_function_calling.py | 59 +++++++-
4 files changed, 188 insertions(+), 52 deletions(-)
diff --git a/python/sglang/srt/function_call_parser.py b/python/sglang/srt/function_call_parser.py
index 4ae8d0a0d..457258ff4 100644
--- a/python/sglang/srt/function_call_parser.py
+++ b/python/sglang/srt/function_call_parser.py
@@ -1,12 +1,21 @@
import json
import logging
import re
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
from json import JSONDecodeError, JSONDecoder
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
import partial_json_parser
+from partial_json_parser.core.exceptions import MalformedJSON
from partial_json_parser.core.options import Allow
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
+
+from sglang.srt.openai_api.protocol import (
+ StructuralTagResponseFormat,
+ StructuresResponseFormat,
+ Tool,
+)
logger = logging.getLogger(__name__)
@@ -19,14 +28,6 @@ TOOLS_TAG_LIST = [
]
-class Function(BaseModel):
- """Function Tool Template."""
-
- description: Optional[str] = Field(default=None, examples=[None])
- name: Optional[str] = None
- parameters: Optional[object] = None
-
-
class ToolCallItem(BaseModel):
"""Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""
@@ -74,7 +75,22 @@ class StreamingParseResult:
self.calls = calls or []
-class BaseFormatDetector:
+@dataclass
+class StructureInfo:
+ begin: str
+ end: str
+ trigger: str
+
+
+_GetInfoFunc = Callable[[str], StructureInfo]
+"""
+helper alias of function
+ususally it is a function that takes a name string and returns a StructureInfo object,
+which can be used to construct a structural_tag object
+"""
+
+
+class BaseFormatDetector(ABC):
"""Base class providing two sets of interfaces: one-time and streaming incremental."""
def __init__(self):
@@ -90,26 +106,12 @@ class BaseFormatDetector:
self.bot_token = ""
self.eot_token = ""
- def parse_base_json(self, action: Any, tools: List[Function]) -> List[ToolCallItem]:
+ def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
tool_indices = {
tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
}
if not isinstance(action, list):
- name = action.get("name")
- if not name or name not in tool_indices:
- logger.warning(f"Model attempted to call undefined function: {name}")
- return []
-
- return [
- ToolCallItem(
- tool_index=tool_indices[name],
- name=name,
- parameters=json.dumps(
- action.get("parameters") or action.get("arguments", {}),
- ensure_ascii=False,
- ),
- )
- ]
+ action = [action]
results = []
for act in action:
@@ -125,12 +127,13 @@ class BaseFormatDetector:
),
)
)
+ else:
+ logger.warning(f"Model attempted to call undefined function: {name}")
return results
- def detect_and_parse(
- self, text: str, tools: List[Function]
- ) -> StreamingParseResult:
+ @abstractmethod
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
"""
Parses the text in one go. Returns success=True if the format matches, otherwise False.
Note that leftover_text here represents "content that this parser will not consume further".
@@ -139,7 +142,7 @@ class BaseFormatDetector:
return StreamingParseResult(calls=self.parse_base_json(action, tools))
def parse_streaming_increment(
- self, new_text: str, tools: List[Function]
+ self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
"""
Streaming incremental parsing with tool validation.
@@ -198,7 +201,7 @@ class BaseFormatDetector:
obj["arguments"] = obj["parameters"]
tool_call_arr.append(obj)
- except partial_json_parser.core.exceptions.MalformedJSON:
+ except MalformedJSON:
return StreamingParseResult()
if len(tool_call_arr) == 0:
@@ -304,6 +307,14 @@ class BaseFormatDetector:
logger.error(f"Error in parse_streaming_increment: {e}")
return StreamingParseResult()
+ @abstractmethod
+ def has_tool_call(self, text: str) -> bool:
+ raise NotImplementedError()
+
+ @abstractmethod
+ def structure_info(self) -> _GetInfoFunc:
+ raise NotImplementedError()
+
class Qwen25Detector(BaseFormatDetector):
"""
@@ -324,9 +335,7 @@ class Qwen25Detector(BaseFormatDetector):
"""Check if the text contains a Qwen 2.5 format tool call."""
return self.bot_token in text
- def detect_and_parse(
- self, text: str, tools: List[Function]
- ) -> StreamingParseResult:
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
"""
One-time parsing: Detects and parses tool calls in the provided text.
@@ -346,6 +355,13 @@ class Qwen25Detector(BaseFormatDetector):
calls.extend(self.parse_base_json(match_result, tools))
return StreamingParseResult(normal_text=normal_text, calls=calls)
+ def structure_info(self) -> _GetInfoFunc:
+ return lambda name: StructureInfo(
+ begin='{"name":"' + name + '", "arguments":',
+ end="}",
+ trigger="",
+ )
+
class MistralDetector(BaseFormatDetector):
"""
@@ -380,9 +396,7 @@ class MistralDetector(BaseFormatDetector):
else:
return ""
- def detect_and_parse(
- self, text: str, tools: List[Function]
- ) -> StreamingParseResult:
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
"""
One-time parsing: Detects and parses tool calls in the provided text.
@@ -403,6 +417,13 @@ class MistralDetector(BaseFormatDetector):
calls.extend(self.parse_base_json(match_result, tools))
return StreamingParseResult(normal_text=normal_text, calls=calls)
+ def structure_info(self) -> _GetInfoFunc:
+ return lambda name: StructureInfo(
+ begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
+ end="}]",
+ trigger="[TOOL_CALLS]",
+ )
+
class Llama32Detector(BaseFormatDetector):
"""
@@ -421,15 +442,15 @@ class Llama32Detector(BaseFormatDetector):
# prefix the output with the <|python_tag|> token
return "<|python_tag|>" in text or text.startswith("{")
- def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
"""Parse function calls from text, handling multiple JSON objects."""
if "<|python_tag|>" not in text and not text.startswith("{"):
return StreamingParseResult(normal_text=text, calls=[])
if "<|python_tag|>" in text:
- _, action_text = text.split("<|python_tag|>")
+ normal_text, action_text = text.split("<|python_tag|>")
else:
- action_text = text
+ normal_text, action_text = "", text
# Split by semicolon and process each part
json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
@@ -449,6 +470,13 @@ class Llama32Detector(BaseFormatDetector):
calls = self.parse_base_json(all_actions, tools)
return StreamingParseResult(normal_text=normal_text, calls=calls)
+ def structure_info(self) -> _GetInfoFunc:
+ return lambda name: StructureInfo(
+ begin='<|python_tag|>{"name":"' + name + '", "arguments":',
+ end="}",
+ trigger="<|python_tag|>",
+ )
+
class MultiFormatParser:
def __init__(self, detectors: List[BaseFormatDetector]):
@@ -458,7 +486,7 @@ class MultiFormatParser:
self.detectors = detectors
def parse_once(
- self, text: str, tools: List[Function]
+ self, text: str, tools: List[Tool]
) -> Tuple[str, list[ToolCallItem]]:
"""
One-time parsing: Loop through detectors until there are no new matches or text is exhausted
@@ -480,7 +508,7 @@ class MultiFormatParser:
return final_normal_text, final_calls
def parse_streaming_increment(
- self, new_text: str, tools: List[Function]
+ self, new_text: str, tools: List[Tool]
) -> Tuple[str, list[ToolCallItem]]:
"""
Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment
@@ -512,13 +540,13 @@ class FunctionCallParser:
and returns the resulting normal_text and calls to the upper layer (or SSE).
"""
- ToolCallParserEnum: Dict[str, BaseFormatDetector] = {
+ ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = {
"llama3": Llama32Detector,
"qwen25": Qwen25Detector,
"mistral": MistralDetector,
}
- def __init__(self, tools: List[Function], tool_call_parser: str = None):
+ def __init__(self, tools: List[Tool], tool_call_parser: str):
detectors = []
if tool_call_parser:
detector_class = self.ToolCallParserEnum.get(tool_call_parser)
@@ -563,3 +591,40 @@ class FunctionCallParser:
chunk_text, self.tools
)
return normal_text, calls
+
+ def structure_infos(self) -> List[_GetInfoFunc]:
+ """
+ Returns a list of structure_info functions for each detector
+ """
+ return [
+ detector.structure_info() for detector in self.multi_format_parser.detectors
+ ]
+
+ def get_structure_tag(self) -> StructuralTagResponseFormat:
+ tool_structures: List[StructuresResponseFormat] = list()
+ tool_trigger_set: Set[str] = set()
+
+ for wrapper in self.structure_infos():
+ for tool in self.tools:
+ function = tool.function
+ name = function.name
+ assert name is not None
+ info = wrapper(name)
+
+ # accept all if not strict, otherwise only accept the schema
+ schema = function.parameters if function.strict else {}
+
+ tool_structures.append(
+ StructuresResponseFormat(
+ begin=info.begin,
+ schema=schema, # type: ignore
+ end=info.end,
+ )
+ )
+ tool_trigger_set.add(info.trigger)
+
+ return StructuralTagResponseFormat(
+ type="structural_tag",
+ structures=tool_structures,
+ triggers=list(tool_trigger_set),
+ )
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index da78de9c4..09e7bf3b4 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -20,7 +20,7 @@ import os
import time
import uuid
from http import HTTPStatus
-from typing import Dict, List
+from typing import Any, Dict, List, Set
from fastapi import HTTPException, Request, UploadFile
from fastapi.responses import ORJSONResponse, StreamingResponse
@@ -38,7 +38,7 @@ from sglang.srt.conversation import (
generate_embedding_convs,
register_conv_template,
)
-from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
+from sglang.srt.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.openai_api.protocol import (
BatchRequest,
@@ -915,6 +915,7 @@ def v1_chat_generate_request(
# - image_data: None or a list of image strings (URLs or base64 strings).
# - audio_data: None or a list of audio strings (URLs).
# None skips any image processing in GenerateReqInput.
+ strict_tag = None
if not isinstance(request.messages, str):
# Apply chat template and its stop strings.
tools = None
@@ -929,6 +930,10 @@ def v1_chat_generate_request(
else:
tools = [item.function.model_dump() for item in request.tools]
+ tool_call_parser = tokenizer_manager.server_args.tool_call_parser
+ parser = FunctionCallParser(request.tools, tool_call_parser)
+ strict_tag = parser.get_structure_tag()
+
if chat_template_name is None:
openai_compatible_messages = []
for message in request.messages:
@@ -1036,6 +1041,22 @@ def v1_chat_generate_request(
sampling_params["structural_tag"] = convert_json_schema_to_str(
request.response_format.model_dump(by_alias=True)
)
+
+ if strict_tag is not None:
+ if (
+ sampling_params.get("regex")
+ or sampling_params.get("ebnf")
+ or sampling_params.get("structural_tag")
+ or sampling_params.get("json_schema")
+ ):
+ logger.warning(
+ "Constrained decoding is not compatible with tool calls."
+ )
+ else:
+ sampling_params["structural_tag"] = convert_json_schema_to_str(
+ strict_tag.model_dump(by_alias=True)
+ )
+
sampling_params_list.append(sampling_params)
image_data_list.append(image_data)
diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
index 1f88a4d13..c4b89c870 100644
--- a/python/sglang/srt/openai_api/protocol.py
+++ b/python/sglang/srt/openai_api/protocol.py
@@ -287,6 +287,7 @@ class Function(BaseModel):
description: Optional[str] = Field(default=None, examples=[None])
name: Optional[str] = None
parameters: Optional[object] = None
+ strict: bool = False
class Tool(BaseModel):
diff --git a/test/srt/test_function_calling.py b/test/srt/test_function_calling.py
index 24f341a5e..f422d5ea5 100644
--- a/test/srt/test_function_calling.py
+++ b/test/srt/test_function_calling.py
@@ -237,12 +237,61 @@ class TestOpenAIServerFunctionCalling(unittest.TestCase):
self.assertIn("a", args_obj, "Missing parameter 'a'")
self.assertIn("b", args_obj, "Missing parameter 'b'")
- self.assertEqual(
- args_obj["a"],
- 5,
- "Parameter a should be 5",
+ self.assertEqual(str(args_obj["a"]), "5", "Parameter a should be 5")
+ self.assertEqual(str(args_obj["b"]), "7", "Parameter b should be 7")
+
+ def test_function_call_strict(self):
+ """
+ Test: Whether the strict mode of function calling works as expected.
+ - When strict mode is enabled, the AI should not return a function call if the function name is not recognized.
+ """
+ client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "sub",
+ "description": "Compute the difference of two integers",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "int_a": {
+ "type": "int",
+ "description": "First integer",
+ },
+ "int_b": {
+ "type": "int",
+ "description": "Second integer",
+ },
+ },
+ "required": ["int_a", "int_b"],
+ },
+ "strict": True,
+ },
+ }
+ ]
+
+ messages = [
+ {"role": "user", "content": "Please compute 5 - 7, using your tool."}
+ ]
+ response = client.chat.completions.create(
+ model=self.model,
+ messages=messages,
+ temperature=0.8,
+ top_p=0.8,
+ stream=False,
+ tools=tools,
)
- self.assertEqual(args_obj["b"], 7, "Parameter b should be 7")
+
+ tool_calls = response.choices[0].message.tool_calls
+ function_name = tool_calls[0].function.name
+ arguments = tool_calls[0].function.arguments
+ args_obj = json.loads(arguments)
+
+ self.assertEqual(function_name, "sub", "Function name should be 'sub'")
+ self.assertEqual(str(args_obj["int_a"]), "5", "Parameter int_a should be 5")
+ self.assertEqual(str(args_obj["int_b"]), "7", "Parameter int_b should be 7")
if __name__ == "__main__":