[Feature] New structural tag support (#10691)
This commit is contained in:
@@ -32,6 +32,7 @@ from sglang.srt.constrained.base_grammar_backend import (
|
||||
BaseGrammarBackend,
|
||||
BaseGrammarObject,
|
||||
)
|
||||
from sglang.srt.constrained.utils import is_legacy_structural_tag
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -160,6 +161,7 @@ class GuidanceBackend(BaseGrammarBackend):
|
||||
def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]:
|
||||
try:
|
||||
structural_tag = json.loads(key_string)
|
||||
assert is_legacy_structural_tag(structural_tag)
|
||||
tags = [
|
||||
StructTag(
|
||||
begin=structure["begin"],
|
||||
|
||||
12
python/sglang/srt/constrained/utils.py
Normal file
12
python/sglang/srt/constrained/utils.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def is_legacy_structural_tag(obj: Dict) -> bool:
|
||||
# test whether an object is a legacy structural tag
|
||||
# see `StructuralTagResponseFormat` at `sglang.srt.entrypoints.openai.protocol`
|
||||
if obj.get("structures", None) is not None:
|
||||
assert obj.get("triggers", None) is not None
|
||||
return True
|
||||
else:
|
||||
assert obj.get("format", None) is not None
|
||||
return False
|
||||
@@ -34,6 +34,7 @@ from sglang.srt.constrained.base_grammar_backend import (
|
||||
BaseGrammarObject,
|
||||
GrammarStats,
|
||||
)
|
||||
from sglang.srt.constrained.utils import is_legacy_structural_tag
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
@@ -241,18 +242,22 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
|
||||
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
|
||||
try:
|
||||
# TODO(dark): it's REALLY stupid to construct object from string and decode it again
|
||||
structural_tag = json.loads(key_string)
|
||||
tags = [
|
||||
StructuralTagItem(
|
||||
begin=structure["begin"],
|
||||
schema=json.dumps(structure["schema"]),
|
||||
end=structure["end"],
|
||||
if is_legacy_structural_tag(structural_tag):
|
||||
tags = [
|
||||
StructuralTagItem(
|
||||
begin=structure["begin"],
|
||||
schema=json.dumps(structure["schema"]),
|
||||
end=structure["end"],
|
||||
)
|
||||
for structure in structural_tag["structures"]
|
||||
]
|
||||
ctx = self.grammar_compiler.compile_structural_tag(
|
||||
tags, structural_tag["triggers"]
|
||||
)
|
||||
for structure in structural_tag["structures"]
|
||||
]
|
||||
ctx = self.grammar_compiler.compile_structural_tag(
|
||||
tags, structural_tag["triggers"]
|
||||
)
|
||||
else:
|
||||
ctx = self.grammar_compiler.compile_structural_tag(key_string)
|
||||
except (RuntimeError, json.decoder.JSONDecodeError) as e:
|
||||
logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
|
||||
return INVALID_GRAMMAR_OBJ
|
||||
|
||||
@@ -17,7 +17,7 @@ import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, TypeAlias, Union
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TypeAlias, Union
|
||||
|
||||
from openai.types.responses import (
|
||||
ResponseFunctionToolCall,
|
||||
@@ -37,6 +37,7 @@ from pydantic import (
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Literal
|
||||
from xgrammar import StructuralTag
|
||||
|
||||
from sglang.utils import convert_json_schema_to_str
|
||||
|
||||
@@ -128,12 +129,23 @@ class StructuresResponseFormat(BaseModel):
|
||||
end: str
|
||||
|
||||
|
||||
class StructuralTagResponseFormat(BaseModel):
|
||||
# NOTE(dark): keep this for backward compatibility
|
||||
class LegacyStructuralTagResponseFormat(BaseModel):
|
||||
type: Literal["structural_tag"]
|
||||
structures: List[StructuresResponseFormat]
|
||||
triggers: List[str]
|
||||
|
||||
|
||||
StructuralTagResponseFormat: TypeAlias = Union[
|
||||
LegacyStructuralTagResponseFormat, StructuralTag
|
||||
]
|
||||
|
||||
ToolCallConstraint: TypeAlias = Union[
|
||||
Tuple[Literal["structural_tag"], StructuralTagResponseFormat],
|
||||
Tuple[Literal["json_schema"], Any], # json_schema can be dict/str/None
|
||||
]
|
||||
|
||||
|
||||
class FileRequest(BaseModel):
|
||||
# https://platform.openai.com/docs/api-reference/files/create
|
||||
file: bytes # The File object (not file name) to be uploaded
|
||||
@@ -583,7 +595,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
self,
|
||||
stop: List[str],
|
||||
model_generation_config: Dict[str, Any],
|
||||
tool_call_constraint: Optional[Any] = None,
|
||||
tool_call_constraint: Optional[ToolCallConstraint] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert request to sampling parameters.
|
||||
@@ -649,7 +661,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
)
|
||||
elif constraint_type == "json_schema":
|
||||
sampling_params[constraint_type] = convert_json_schema_to_str(
|
||||
constraint_value
|
||||
constraint_value # type: ignore
|
||||
)
|
||||
else:
|
||||
sampling_params[constraint_type] = constraint_value
|
||||
@@ -1177,7 +1189,7 @@ class MessageProcessingResult:
|
||||
video_data: Optional[Any]
|
||||
modalities: List[str]
|
||||
stop: List[str]
|
||||
tool_call_constraint: Optional[Any] = None
|
||||
tool_call_constraint: Optional[ToolCallConstraint] = None
|
||||
|
||||
|
||||
class ToolCallProcessingResult(NamedTuple):
|
||||
|
||||
@@ -2,9 +2,10 @@ import logging
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
StructuralTagResponseFormat,
|
||||
LegacyStructuralTagResponseFormat,
|
||||
StructuresResponseFormat,
|
||||
Tool,
|
||||
ToolCallConstraint,
|
||||
ToolChoice,
|
||||
)
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
@@ -51,7 +52,6 @@ class FunctionCallParser:
|
||||
}
|
||||
|
||||
def __init__(self, tools: List[Tool], tool_call_parser: str):
|
||||
detector: Type[BaseFormatDetector] = None
|
||||
detector_class = self.ToolCallParserEnum.get(tool_call_parser)
|
||||
if detector_class:
|
||||
detector = detector_class()
|
||||
@@ -123,7 +123,7 @@ class FunctionCallParser:
|
||||
|
||||
return final_normal_text, final_calls
|
||||
|
||||
def get_structure_tag(self) -> StructuralTagResponseFormat:
|
||||
def get_structure_tag(self) -> LegacyStructuralTagResponseFormat:
|
||||
"""
|
||||
Generate a structural tag response format for all available tools.
|
||||
|
||||
@@ -151,7 +151,9 @@ class FunctionCallParser:
|
||||
)
|
||||
tool_trigger_set.add(info.trigger)
|
||||
|
||||
return StructuralTagResponseFormat(
|
||||
# TODO(dark): move this into new structural tag format
|
||||
# This requires all grammar backend support the new format
|
||||
return LegacyStructuralTagResponseFormat(
|
||||
type="structural_tag",
|
||||
structures=tool_structures,
|
||||
triggers=list(tool_trigger_set),
|
||||
@@ -159,7 +161,7 @@ class FunctionCallParser:
|
||||
|
||||
def get_structure_constraint(
|
||||
self, tool_choice: Union[ToolChoice, Literal["auto", "required"]]
|
||||
) -> Optional[Tuple[str, Any]]:
|
||||
) -> Optional[ToolCallConstraint]:
|
||||
"""
|
||||
Returns the appropriate structure constraint for tool calls based on the tool_choice.
|
||||
The constraint is used to guide the model's output format.
|
||||
@@ -178,8 +180,8 @@ class FunctionCallParser:
|
||||
and tool_choice == "auto"
|
||||
and any(tool.function.strict for tool in self.tools)
|
||||
):
|
||||
strict_tag = self.get_structure_tag()
|
||||
return ("structural_tag", strict_tag)
|
||||
tag = self.get_structure_tag()
|
||||
return ("structural_tag", tag)
|
||||
elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
|
||||
json_schema = get_json_schema_constraint(self.tools, tool_choice)
|
||||
return ("json_schema", json_schema)
|
||||
|
||||
Reference in New Issue
Block a user