From ac3fae8445da9791ce95db9fc28db01976b68e4c Mon Sep 17 00:00:00 2001 From: DarkSharpness <76582120+DarkSharpness@users.noreply.github.com> Date: Tue, 25 Mar 2025 14:15:25 +0900 Subject: [PATCH] [Feature] Support "strict" in function calling (#4310) --- python/sglang/srt/function_call_parser.py | 155 +++++++++++++++------- python/sglang/srt/openai_api/adapter.py | 25 +++- python/sglang/srt/openai_api/protocol.py | 1 + test/srt/test_function_calling.py | 59 +++++++- 4 files changed, 188 insertions(+), 52 deletions(-) diff --git a/python/sglang/srt/function_call_parser.py b/python/sglang/srt/function_call_parser.py index 4ae8d0a0d..457258ff4 100644 --- a/python/sglang/srt/function_call_parser.py +++ b/python/sglang/srt/function_call_parser.py @@ -1,12 +1,21 @@ import json import logging import re +from abc import ABC, abstractmethod +from dataclasses import dataclass from json import JSONDecodeError, JSONDecoder -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type import partial_json_parser +from partial_json_parser.core.exceptions import MalformedJSON from partial_json_parser.core.options import Allow -from pydantic import BaseModel, Field +from pydantic import BaseModel + +from sglang.srt.openai_api.protocol import ( + StructuralTagResponseFormat, + StructuresResponseFormat, + Tool, +) logger = logging.getLogger(__name__) @@ -19,14 +28,6 @@ TOOLS_TAG_LIST = [ ] -class Function(BaseModel): - """Function Tool Template.""" - - description: Optional[str] = Field(default=None, examples=[None]) - name: Optional[str] = None - parameters: Optional[object] = None - - class ToolCallItem(BaseModel): """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts.""" @@ -74,7 +75,22 @@ class StreamingParseResult: self.calls = calls or [] -class BaseFormatDetector: +@dataclass +class StructureInfo: + begin: str + end: str + trigger: str + + +_GetInfoFunc = Callable[[str], StructureInfo] +""" +helper alias of function +ususally it is a function that takes a name string and returns a StructureInfo object, +which can be used to construct a structural_tag object +""" + + +class BaseFormatDetector(ABC): """Base class providing two sets of interfaces: one-time and streaming incremental.""" def __init__(self): @@ -90,26 +106,12 @@ class BaseFormatDetector: self.bot_token = "" self.eot_token = "" - def parse_base_json(self, action: Any, tools: List[Function]) -> List[ToolCallItem]: + def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]: tool_indices = { tool.function.name: i for i, tool in enumerate(tools) if tool.function.name } if not isinstance(action, list): - name = action.get("name") - if not name or name not in tool_indices: - logger.warning(f"Model attempted to call undefined function: {name}") - return [] - - return [ - ToolCallItem( - tool_index=tool_indices[name], - name=name, - parameters=json.dumps( - action.get("parameters") or action.get("arguments", {}), - ensure_ascii=False, - ), - ) - ] + action = [action] results = [] for act in action: @@ -125,12 +127,13 @@ class BaseFormatDetector: ), ) ) + else: + logger.warning(f"Model attempted to call undefined function: {name}") return results - def detect_and_parse( - self, text: str, tools: List[Function] - ) -> StreamingParseResult: + @abstractmethod + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: """ Parses the text in one go. Returns success=True if the format matches, otherwise False. Note that leftover_text here represents "content that this parser will not consume further". @@ -139,7 +142,7 @@ class BaseFormatDetector: return StreamingParseResult(calls=self.parse_base_json(action, tools)) def parse_streaming_increment( - self, new_text: str, tools: List[Function] + self, new_text: str, tools: List[Tool] ) -> StreamingParseResult: """ Streaming incremental parsing with tool validation. @@ -198,7 +201,7 @@ class BaseFormatDetector: obj["arguments"] = obj["parameters"] tool_call_arr.append(obj) - except partial_json_parser.core.exceptions.MalformedJSON: + except MalformedJSON: return StreamingParseResult() if len(tool_call_arr) == 0: @@ -304,6 +307,14 @@ class BaseFormatDetector: logger.error(f"Error in parse_streaming_increment: {e}") return StreamingParseResult() + @abstractmethod + def has_tool_call(self, text: str) -> bool: + raise NotImplementedError() + + @abstractmethod + def structure_info(self) -> _GetInfoFunc: + raise NotImplementedError() + class Qwen25Detector(BaseFormatDetector): """ @@ -324,9 +335,7 @@ class Qwen25Detector(BaseFormatDetector): """Check if the text contains a Qwen 2.5 format tool call.""" return self.bot_token in text - def detect_and_parse( - self, text: str, tools: List[Function] - ) -> StreamingParseResult: + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: """ One-time parsing: Detects and parses tool calls in the provided text. @@ -346,6 +355,13 @@ class Qwen25Detector(BaseFormatDetector): calls.extend(self.parse_base_json(match_result, tools)) return StreamingParseResult(normal_text=normal_text, calls=calls) + def structure_info(self) -> _GetInfoFunc: + return lambda name: StructureInfo( + begin='{"name":"' + name + '", "arguments":', + end="}", + trigger="", + ) + class MistralDetector(BaseFormatDetector): """ @@ -380,9 +396,7 @@ class MistralDetector(BaseFormatDetector): else: return "" - def detect_and_parse( - self, text: str, tools: List[Function] - ) -> StreamingParseResult: + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: """ One-time parsing: Detects and parses tool calls in the provided text. @@ -403,6 +417,13 @@ class MistralDetector(BaseFormatDetector): calls.extend(self.parse_base_json(match_result, tools)) return StreamingParseResult(normal_text=normal_text, calls=calls) + def structure_info(self) -> _GetInfoFunc: + return lambda name: StructureInfo( + begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":', + end="}]", + trigger="[TOOL_CALLS]", + ) + class Llama32Detector(BaseFormatDetector): """ @@ -421,15 +442,15 @@ class Llama32Detector(BaseFormatDetector): # prefix the output with the <|python_tag|> token return "<|python_tag|>" in text or text.startswith("{") - def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]: + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: """Parse function calls from text, handling multiple JSON objects.""" if "<|python_tag|>" not in text and not text.startswith("{"): return StreamingParseResult(normal_text=text, calls=[]) if "<|python_tag|>" in text: - _, action_text = text.split("<|python_tag|>") + normal_text, action_text = text.split("<|python_tag|>") else: - action_text = text + normal_text, action_text = "", text # Split by semicolon and process each part json_parts = [part.strip() for part in action_text.split(";") if part.strip()] @@ -449,6 +470,13 @@ class Llama32Detector(BaseFormatDetector): calls = self.parse_base_json(all_actions, tools) return StreamingParseResult(normal_text=normal_text, calls=calls) + def structure_info(self) -> _GetInfoFunc: + return lambda name: StructureInfo( + begin='<|python_tag|>{"name":"' + name + '", "arguments":', + end="}", + trigger="<|python_tag|>", + ) + class MultiFormatParser: def __init__(self, detectors: List[BaseFormatDetector]): @@ -458,7 +486,7 @@ class MultiFormatParser: self.detectors = detectors def parse_once( - self, text: str, tools: List[Function] + self, text: str, tools: List[Tool] ) -> Tuple[str, list[ToolCallItem]]: """ One-time parsing: Loop through detectors until there are no new matches or text is exhausted @@ -480,7 +508,7 @@ class MultiFormatParser: return final_normal_text, final_calls def parse_streaming_increment( - self, new_text: str, tools: List[Function] + self, new_text: str, tools: List[Tool] ) -> Tuple[str, list[ToolCallItem]]: """ Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment @@ -512,13 +540,13 @@ class FunctionCallParser: and returns the resulting normal_text and calls to the upper layer (or SSE). """ - ToolCallParserEnum: Dict[str, BaseFormatDetector] = { + ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = { "llama3": Llama32Detector, "qwen25": Qwen25Detector, "mistral": MistralDetector, } - def __init__(self, tools: List[Function], tool_call_parser: str = None): + def __init__(self, tools: List[Tool], tool_call_parser: str): detectors = [] if tool_call_parser: detector_class = self.ToolCallParserEnum.get(tool_call_parser) @@ -563,3 +591,40 @@ class FunctionCallParser: chunk_text, self.tools ) return normal_text, calls + + def structure_infos(self) -> List[_GetInfoFunc]: + """ + Returns a list of structure_info functions for each detector + """ + return [ + detector.structure_info() for detector in self.multi_format_parser.detectors + ] + + def get_structure_tag(self) -> StructuralTagResponseFormat: + tool_structures: List[StructuresResponseFormat] = list() + tool_trigger_set: Set[str] = set() + + for wrapper in self.structure_infos(): + for tool in self.tools: + function = tool.function + name = function.name + assert name is not None + info = wrapper(name) + + # accept all if not strict, otherwise only accept the schema + schema = function.parameters if function.strict else {} + + tool_structures.append( + StructuresResponseFormat( + begin=info.begin, + schema=schema, # type: ignore + end=info.end, + ) + ) + tool_trigger_set.add(info.trigger) + + return StructuralTagResponseFormat( + type="structural_tag", + structures=tool_structures, + triggers=list(tool_trigger_set), + ) diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index da78de9c4..09e7bf3b4 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -20,7 +20,7 @@ import os import time import uuid from http import HTTPStatus -from typing import Dict, List +from typing import Any, Dict, List, Set from fastapi import HTTPException, Request, UploadFile from fastapi.responses import ORJSONResponse, StreamingResponse @@ -38,7 +38,7 @@ from sglang.srt.conversation import ( generate_embedding_convs, register_conv_template, ) -from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser +from sglang.srt.function_call_parser import FunctionCallParser from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput from sglang.srt.openai_api.protocol import ( BatchRequest, @@ -915,6 +915,7 @@ def v1_chat_generate_request( # - image_data: None or a list of image strings (URLs or base64 strings). # - audio_data: None or a list of audio strings (URLs). # None skips any image processing in GenerateReqInput. + strict_tag = None if not isinstance(request.messages, str): # Apply chat template and its stop strings. tools = None @@ -929,6 +930,10 @@ def v1_chat_generate_request( else: tools = [item.function.model_dump() for item in request.tools] + tool_call_parser = tokenizer_manager.server_args.tool_call_parser + parser = FunctionCallParser(request.tools, tool_call_parser) + strict_tag = parser.get_structure_tag() + if chat_template_name is None: openai_compatible_messages = [] for message in request.messages: @@ -1036,6 +1041,22 @@ def v1_chat_generate_request( sampling_params["structural_tag"] = convert_json_schema_to_str( request.response_format.model_dump(by_alias=True) ) + + if strict_tag is not None: + if ( + sampling_params.get("regex") + or sampling_params.get("ebnf") + or sampling_params.get("structural_tag") + or sampling_params.get("json_schema") + ): + logger.warning( + "Constrained decoding is not compatible with tool calls." + ) + else: + sampling_params["structural_tag"] = convert_json_schema_to_str( + strict_tag.model_dump(by_alias=True) + ) + sampling_params_list.append(sampling_params) image_data_list.append(image_data) diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 1f88a4d13..c4b89c870 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -287,6 +287,7 @@ class Function(BaseModel): description: Optional[str] = Field(default=None, examples=[None]) name: Optional[str] = None parameters: Optional[object] = None + strict: bool = False class Tool(BaseModel): diff --git a/test/srt/test_function_calling.py b/test/srt/test_function_calling.py index 24f341a5e..f422d5ea5 100644 --- a/test/srt/test_function_calling.py +++ b/test/srt/test_function_calling.py @@ -237,12 +237,61 @@ class TestOpenAIServerFunctionCalling(unittest.TestCase): self.assertIn("a", args_obj, "Missing parameter 'a'") self.assertIn("b", args_obj, "Missing parameter 'b'") - self.assertEqual( - args_obj["a"], - 5, - "Parameter a should be 5", + self.assertEqual(str(args_obj["a"]), "5", "Parameter a should be 5") + self.assertEqual(str(args_obj["b"]), "7", "Parameter b should be 7") + + def test_function_call_strict(self): + """ + Test: Whether the strict mode of function calling works as expected. + - When strict mode is enabled, the AI should not return a function call if the function name is not recognized. + """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + tools = [ + { + "type": "function", + "function": { + "name": "sub", + "description": "Compute the difference of two integers", + "parameters": { + "type": "object", + "properties": { + "int_a": { + "type": "int", + "description": "First integer", + }, + "int_b": { + "type": "int", + "description": "Second integer", + }, + }, + "required": ["int_a", "int_b"], + }, + "strict": True, + }, + } + ] + + messages = [ + {"role": "user", "content": "Please compute 5 - 7, using your tool."} + ] + response = client.chat.completions.create( + model=self.model, + messages=messages, + temperature=0.8, + top_p=0.8, + stream=False, + tools=tools, ) - self.assertEqual(args_obj["b"], 7, "Parameter b should be 7") + + tool_calls = response.choices[0].message.tool_calls + function_name = tool_calls[0].function.name + arguments = tool_calls[0].function.arguments + args_obj = json.loads(arguments) + + self.assertEqual(function_name, "sub", "Function name should be 'sub'") + self.assertEqual(str(args_obj["int_a"]), "5", "Parameter int_a should be 5") + self.assertEqual(str(args_obj["int_b"]), "7", "Parameter int_b should be 7") if __name__ == "__main__":