[Feature] Support "strict" in function calling (#4310)
This commit is contained in:
@@ -1,12 +1,21 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from json import JSONDecodeError, JSONDecoder
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.exceptions import MalformedJSON
|
||||
from partial_json_parser.core.options import Allow
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sglang.srt.openai_api.protocol import (
|
||||
StructuralTagResponseFormat,
|
||||
StructuresResponseFormat,
|
||||
Tool,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -19,14 +28,6 @@ TOOLS_TAG_LIST = [
|
||||
]
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
"""Function Tool Template."""
|
||||
|
||||
description: Optional[str] = Field(default=None, examples=[None])
|
||||
name: Optional[str] = None
|
||||
parameters: Optional[object] = None
|
||||
|
||||
|
||||
class ToolCallItem(BaseModel):
|
||||
"""Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""
|
||||
|
||||
@@ -74,7 +75,22 @@ class StreamingParseResult:
|
||||
self.calls = calls or []
|
||||
|
||||
|
||||
class BaseFormatDetector:
|
||||
@dataclass
class StructureInfo:
    """Literal text pieces that frame one tool call in a model's output format."""

    # Text emitted before the arguments value (includes the tool name).
    begin: str
    # Text emitted after the arguments value to close the call.
    end: str
    # Marker string registered as a trigger for the structural tag.
    trigger: str
|
||||
|
||||
|
||||
_GetInfoFunc = Callable[[str], StructureInfo]
|
||||
"""
|
||||
helper alias of function
|
||||
usually it is a function that takes a name string and returns a StructureInfo object,
|
||||
which can be used to construct a structural_tag object
|
||||
"""
|
||||
|
||||
|
||||
class BaseFormatDetector(ABC):
|
||||
"""Base class providing two sets of interfaces: one-time and streaming incremental."""
|
||||
|
||||
def __init__(self):
|
||||
@@ -90,26 +106,12 @@ class BaseFormatDetector:
|
||||
self.bot_token = ""
|
||||
self.eot_token = ""
|
||||
|
||||
def parse_base_json(self, action: Any, tools: List[Function]) -> List[ToolCallItem]:
|
||||
def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
|
||||
tool_indices = {
|
||||
tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
|
||||
}
|
||||
if not isinstance(action, list):
|
||||
name = action.get("name")
|
||||
if not name or name not in tool_indices:
|
||||
logger.warning(f"Model attempted to call undefined function: {name}")
|
||||
return []
|
||||
|
||||
return [
|
||||
ToolCallItem(
|
||||
tool_index=tool_indices[name],
|
||||
name=name,
|
||||
parameters=json.dumps(
|
||||
action.get("parameters") or action.get("arguments", {}),
|
||||
ensure_ascii=False,
|
||||
),
|
||||
)
|
||||
]
|
||||
action = [action]
|
||||
|
||||
results = []
|
||||
for act in action:
|
||||
@@ -125,12 +127,13 @@ class BaseFormatDetector:
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Model attempted to call undefined function: {name}")
|
||||
|
||||
return results
|
||||
|
||||
def detect_and_parse(
|
||||
self, text: str, tools: List[Function]
|
||||
) -> StreamingParseResult:
|
||||
@abstractmethod
|
||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||
"""
|
||||
Parses the text in one go. Returns success=True if the format matches, otherwise False.
|
||||
Note that leftover_text here represents "content that this parser will not consume further".
|
||||
@@ -139,7 +142,7 @@ class BaseFormatDetector:
|
||||
return StreamingParseResult(calls=self.parse_base_json(action, tools))
|
||||
|
||||
def parse_streaming_increment(
|
||||
self, new_text: str, tools: List[Function]
|
||||
self, new_text: str, tools: List[Tool]
|
||||
) -> StreamingParseResult:
|
||||
"""
|
||||
Streaming incremental parsing with tool validation.
|
||||
@@ -198,7 +201,7 @@ class BaseFormatDetector:
|
||||
obj["arguments"] = obj["parameters"]
|
||||
tool_call_arr.append(obj)
|
||||
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
except MalformedJSON:
|
||||
return StreamingParseResult()
|
||||
|
||||
if len(tool_call_arr) == 0:
|
||||
@@ -304,6 +307,14 @@ class BaseFormatDetector:
|
||||
logger.error(f"Error in parse_streaming_increment: {e}")
|
||||
return StreamingParseResult()
|
||||
|
||||
@abstractmethod
def has_tool_call(self, text: str) -> bool:
    """Report whether ``text`` contains a tool call in this detector's format.

    Concrete detectors must override this method; the base version only raises.
    """
    raise NotImplementedError
|
||||
|
||||
@abstractmethod
def structure_info(self) -> _GetInfoFunc:
    """Return a callable that maps a tool name to this format's StructureInfo.

    Concrete detectors must override this method; the base version only raises.
    """
    raise NotImplementedError
|
||||
|
||||
|
||||
class Qwen25Detector(BaseFormatDetector):
|
||||
"""
|
||||
@@ -324,9 +335,7 @@ class Qwen25Detector(BaseFormatDetector):
|
||||
"""Check if the text contains a Qwen 2.5 format tool call."""
|
||||
return self.bot_token in text
|
||||
|
||||
def detect_and_parse(
|
||||
self, text: str, tools: List[Function]
|
||||
) -> StreamingParseResult:
|
||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||
"""
|
||||
One-time parsing: Detects and parses tool calls in the provided text.
|
||||
|
||||
@@ -346,6 +355,13 @@ class Qwen25Detector(BaseFormatDetector):
|
||||
calls.extend(self.parse_base_json(match_result, tools))
|
||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||
|
||||
def structure_info(self) -> _GetInfoFunc:
    """Return a builder of the ``<tool_call>`` framing text for a given tool name."""

    def build_info(name: str) -> StructureInfo:
        # The opening text fixes the tool name and stops right before the
        # arguments value; the closing text ends the JSON object and the tag.
        opening = '<tool_call>{"name":"' + name + '", "arguments":'
        return StructureInfo(
            begin=opening,
            end="}</tool_call>",
            trigger="<tool_call>",
        )

    return build_info
|
||||
|
||||
|
||||
class MistralDetector(BaseFormatDetector):
|
||||
"""
|
||||
@@ -380,9 +396,7 @@ class MistralDetector(BaseFormatDetector):
|
||||
else:
|
||||
return ""
|
||||
|
||||
def detect_and_parse(
|
||||
self, text: str, tools: List[Function]
|
||||
) -> StreamingParseResult:
|
||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||
"""
|
||||
One-time parsing: Detects and parses tool calls in the provided text.
|
||||
|
||||
@@ -403,6 +417,13 @@ class MistralDetector(BaseFormatDetector):
|
||||
calls.extend(self.parse_base_json(match_result, tools))
|
||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||
|
||||
def structure_info(self) -> _GetInfoFunc:
    """Return a builder of the ``[TOOL_CALLS]`` framing text for a given tool name."""

    def build_info(name: str) -> StructureInfo:
        # Opening text: the marker, the start of a one-element JSON array and
        # the tool name, stopping right before the arguments value.
        opening = "".join(('[TOOL_CALLS] [{"name":"', name, '", "arguments":'))
        return StructureInfo(
            begin=opening,
            end="}]",
            trigger="[TOOL_CALLS]",
        )

    return build_info
|
||||
|
||||
|
||||
class Llama32Detector(BaseFormatDetector):
|
||||
"""
|
||||
@@ -421,15 +442,15 @@ class Llama32Detector(BaseFormatDetector):
|
||||
# prefix the output with the <|python_tag|> token
|
||||
return "<|python_tag|>" in text or text.startswith("{")
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||
"""Parse function calls from text, handling multiple JSON objects."""
|
||||
if "<|python_tag|>" not in text and not text.startswith("{"):
|
||||
return StreamingParseResult(normal_text=text, calls=[])
|
||||
|
||||
if "<|python_tag|>" in text:
|
||||
_, action_text = text.split("<|python_tag|>")
|
||||
normal_text, action_text = text.split("<|python_tag|>")
|
||||
else:
|
||||
action_text = text
|
||||
normal_text, action_text = "", text
|
||||
|
||||
# Split by semicolon and process each part
|
||||
json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
|
||||
@@ -449,6 +470,13 @@ class Llama32Detector(BaseFormatDetector):
|
||||
calls = self.parse_base_json(all_actions, tools)
|
||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||
|
||||
def structure_info(self) -> _GetInfoFunc:
    """Return a builder of the ``<|python_tag|>`` framing text for a given tool name."""
    python_tag = "<|python_tag|>"

    def build_info(name: str) -> StructureInfo:
        # The opening text is the tag followed by the JSON object up to, but
        # not including, the arguments value; "}" then closes the object.
        opening = python_tag + '{"name":"' + name + '", "arguments":'
        return StructureInfo(begin=opening, end="}", trigger=python_tag)

    return build_info
|
||||
|
||||
|
||||
class MultiFormatParser:
|
||||
def __init__(self, detectors: List[BaseFormatDetector]):
|
||||
@@ -458,7 +486,7 @@ class MultiFormatParser:
|
||||
self.detectors = detectors
|
||||
|
||||
def parse_once(
|
||||
self, text: str, tools: List[Function]
|
||||
self, text: str, tools: List[Tool]
|
||||
) -> Tuple[str, list[ToolCallItem]]:
|
||||
"""
|
||||
One-time parsing: Loop through detectors until there are no new matches or text is exhausted
|
||||
@@ -480,7 +508,7 @@ class MultiFormatParser:
|
||||
return final_normal_text, final_calls
|
||||
|
||||
def parse_streaming_increment(
|
||||
self, new_text: str, tools: List[Function]
|
||||
self, new_text: str, tools: List[Tool]
|
||||
) -> Tuple[str, list[ToolCallItem]]:
|
||||
"""
|
||||
Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment
|
||||
@@ -512,13 +540,13 @@ class FunctionCallParser:
|
||||
and returns the resulting normal_text and calls to the upper layer (or SSE).
|
||||
"""
|
||||
|
||||
ToolCallParserEnum: Dict[str, BaseFormatDetector] = {
|
||||
ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = {
|
||||
"llama3": Llama32Detector,
|
||||
"qwen25": Qwen25Detector,
|
||||
"mistral": MistralDetector,
|
||||
}
|
||||
|
||||
def __init__(self, tools: List[Function], tool_call_parser: str = None):
|
||||
def __init__(self, tools: List[Tool], tool_call_parser: str):
|
||||
detectors = []
|
||||
if tool_call_parser:
|
||||
detector_class = self.ToolCallParserEnum.get(tool_call_parser)
|
||||
@@ -563,3 +591,40 @@ class FunctionCallParser:
|
||||
chunk_text, self.tools
|
||||
)
|
||||
return normal_text, calls
|
||||
|
||||
def structure_infos(self) -> List[_GetInfoFunc]:
    """
    Collect the structure_info callable of every detector, in detector order.
    """
    info_funcs = []
    for detector in self.multi_format_parser.detectors:
        info_funcs.append(detector.structure_info())
    return info_funcs
|
||||
|
||||
def get_structure_tag(self) -> StructuralTagResponseFormat:
    """
    Build the structural-tag response format covering every tool of this parser.

    For each detector's structure_info callable and each tool, one structure is
    added whose begin/end text comes from the resulting StructureInfo. A tool
    whose function is marked ``strict`` is constrained to its declared
    ``parameters`` schema; any other tool gets an empty schema.

    Returns:
        A StructuralTagResponseFormat with all structures and the
        de-duplicated triggers, listed in first-seen order.

    Raises:
        AssertionError: if a tool's function has no name.
    """
    tool_structures: List[StructuresResponseFormat] = []
    # De-duplicate through dict keys rather than a set: dicts keep insertion
    # order, so the emitted trigger list is the same on every run, whereas
    # iterating a set of strings depends on hash randomization.
    tool_triggers: Dict[str, None] = {}

    for wrapper in self.structure_infos():
        for tool in self.tools:
            function = tool.function
            name = function.name
            assert name is not None
            info = wrapper(name)

            # accept all if not strict, otherwise only accept the schema
            schema = function.parameters if function.strict else {}

            tool_structures.append(
                StructuresResponseFormat(
                    begin=info.begin,
                    schema=schema,  # type: ignore
                    end=info.end,
                )
            )
            tool_triggers[info.trigger] = None

    return StructuralTagResponseFormat(
        type="structural_tag",
        structures=tool_structures,
        triggers=list(tool_triggers),
    )
|
||||
|
||||
@@ -20,7 +20,7 @@ import os
|
||||
import time
|
||||
import uuid
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
from fastapi import HTTPException, Request, UploadFile
|
||||
from fastapi.responses import ORJSONResponse, StreamingResponse
|
||||
@@ -38,7 +38,7 @@ from sglang.srt.conversation import (
|
||||
generate_embedding_convs,
|
||||
register_conv_template,
|
||||
)
|
||||
from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
|
||||
from sglang.srt.function_call_parser import FunctionCallParser
|
||||
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
|
||||
from sglang.srt.openai_api.protocol import (
|
||||
BatchRequest,
|
||||
@@ -915,6 +915,7 @@ def v1_chat_generate_request(
|
||||
# - image_data: None or a list of image strings (URLs or base64 strings).
|
||||
# - audio_data: None or a list of audio strings (URLs).
|
||||
# None skips any image processing in GenerateReqInput.
|
||||
strict_tag = None
|
||||
if not isinstance(request.messages, str):
|
||||
# Apply chat template and its stop strings.
|
||||
tools = None
|
||||
@@ -929,6 +930,10 @@ def v1_chat_generate_request(
|
||||
else:
|
||||
tools = [item.function.model_dump() for item in request.tools]
|
||||
|
||||
tool_call_parser = tokenizer_manager.server_args.tool_call_parser
|
||||
parser = FunctionCallParser(request.tools, tool_call_parser)
|
||||
strict_tag = parser.get_structure_tag()
|
||||
|
||||
if chat_template_name is None:
|
||||
openai_compatible_messages = []
|
||||
for message in request.messages:
|
||||
@@ -1036,6 +1041,22 @@ def v1_chat_generate_request(
|
||||
sampling_params["structural_tag"] = convert_json_schema_to_str(
|
||||
request.response_format.model_dump(by_alias=True)
|
||||
)
|
||||
|
||||
if strict_tag is not None:
|
||||
if (
|
||||
sampling_params.get("regex")
|
||||
or sampling_params.get("ebnf")
|
||||
or sampling_params.get("structural_tag")
|
||||
or sampling_params.get("json_schema")
|
||||
):
|
||||
logger.warning(
|
||||
"Constrained decoding is not compatible with tool calls."
|
||||
)
|
||||
else:
|
||||
sampling_params["structural_tag"] = convert_json_schema_to_str(
|
||||
strict_tag.model_dump(by_alias=True)
|
||||
)
|
||||
|
||||
sampling_params_list.append(sampling_params)
|
||||
|
||||
image_data_list.append(image_data)
|
||||
|
||||
@@ -287,6 +287,7 @@ class Function(BaseModel):
|
||||
description: Optional[str] = Field(default=None, examples=[None])
|
||||
name: Optional[str] = None
|
||||
parameters: Optional[object] = None
|
||||
strict: bool = False
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user