[Feature] New structural tag support (#10691)

This commit is contained in:
DarkSharpness
2025-10-20 18:25:58 +08:00
committed by GitHub
parent 296f689242
commit 276e7b3e4e
7 changed files with 326 additions and 22 deletions

View File

@@ -32,6 +32,7 @@ from sglang.srt.constrained.base_grammar_backend import (
BaseGrammarBackend,
BaseGrammarObject,
)
from sglang.srt.constrained.utils import is_legacy_structural_tag
logger = logging.getLogger(__name__)
@@ -160,6 +161,7 @@ class GuidanceBackend(BaseGrammarBackend):
def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]:
try:
structural_tag = json.loads(key_string)
assert is_legacy_structural_tag(structural_tag)
tags = [
StructTag(
begin=structure["begin"],

View File

@@ -0,0 +1,12 @@
from typing import Dict
def is_legacy_structural_tag(obj: Dict) -> bool:
# test whether an object is a legacy structural tag
# see `StructuralTagResponseFormat` at `sglang.srt.entrypoints.openai.protocol`
if obj.get("structures", None) is not None:
assert obj.get("triggers", None) is not None
return True
else:
assert obj.get("format", None) is not None
return False

View File

@@ -34,6 +34,7 @@ from sglang.srt.constrained.base_grammar_backend import (
BaseGrammarObject,
GrammarStats,
)
from sglang.srt.constrained.utils import is_legacy_structural_tag
from sglang.srt.utils import is_hip
_is_hip = is_hip()
@@ -241,18 +242,22 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
try:
# TODO(dark): it's REALLY stupid to construct object from string and decode it again
structural_tag = json.loads(key_string)
tags = [
StructuralTagItem(
begin=structure["begin"],
schema=json.dumps(structure["schema"]),
end=structure["end"],
if is_legacy_structural_tag(structural_tag):
tags = [
StructuralTagItem(
begin=structure["begin"],
schema=json.dumps(structure["schema"]),
end=structure["end"],
)
for structure in structural_tag["structures"]
]
ctx = self.grammar_compiler.compile_structural_tag(
tags, structural_tag["triggers"]
)
for structure in structural_tag["structures"]
]
ctx = self.grammar_compiler.compile_structural_tag(
tags, structural_tag["triggers"]
)
else:
ctx = self.grammar_compiler.compile_structural_tag(key_string)
except (RuntimeError, json.decoder.JSONDecodeError) as e:
logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
return INVALID_GRAMMAR_OBJ

View File

@@ -17,7 +17,7 @@ import logging
import time
import uuid
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, TypeAlias, Union
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TypeAlias, Union
from openai.types.responses import (
ResponseFunctionToolCall,
@@ -37,6 +37,7 @@ from pydantic import (
model_validator,
)
from typing_extensions import Literal
from xgrammar import StructuralTag
from sglang.utils import convert_json_schema_to_str
@@ -128,12 +129,23 @@ class StructuresResponseFormat(BaseModel):
end: str
class StructuralTagResponseFormat(BaseModel):
# NOTE(dark): keep this for backward compatibility
class LegacyStructuralTagResponseFormat(BaseModel):
type: Literal["structural_tag"]
structures: List[StructuresResponseFormat]
triggers: List[str]
StructuralTagResponseFormat: TypeAlias = Union[
LegacyStructuralTagResponseFormat, StructuralTag
]
ToolCallConstraint: TypeAlias = Union[
Tuple[Literal["structural_tag"], StructuralTagResponseFormat],
Tuple[Literal["json_schema"], Any], # json_schema can be dict/str/None
]
class FileRequest(BaseModel):
# https://platform.openai.com/docs/api-reference/files/create
file: bytes # The File object (not file name) to be uploaded
@@ -583,7 +595,7 @@ class ChatCompletionRequest(BaseModel):
self,
stop: List[str],
model_generation_config: Dict[str, Any],
tool_call_constraint: Optional[Any] = None,
tool_call_constraint: Optional[ToolCallConstraint] = None,
) -> Dict[str, Any]:
"""
Convert request to sampling parameters.
@@ -649,7 +661,7 @@ class ChatCompletionRequest(BaseModel):
)
elif constraint_type == "json_schema":
sampling_params[constraint_type] = convert_json_schema_to_str(
constraint_value
constraint_value # type: ignore
)
else:
sampling_params[constraint_type] = constraint_value
@@ -1177,7 +1189,7 @@ class MessageProcessingResult:
video_data: Optional[Any]
modalities: List[str]
stop: List[str]
tool_call_constraint: Optional[Any] = None
tool_call_constraint: Optional[ToolCallConstraint] = None
class ToolCallProcessingResult(NamedTuple):

View File

@@ -2,9 +2,10 @@ import logging
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
from sglang.srt.entrypoints.openai.protocol import (
StructuralTagResponseFormat,
LegacyStructuralTagResponseFormat,
StructuresResponseFormat,
Tool,
ToolCallConstraint,
ToolChoice,
)
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
@@ -51,7 +52,6 @@ class FunctionCallParser:
}
def __init__(self, tools: List[Tool], tool_call_parser: str):
detector: Type[BaseFormatDetector] = None
detector_class = self.ToolCallParserEnum.get(tool_call_parser)
if detector_class:
detector = detector_class()
@@ -123,7 +123,7 @@ class FunctionCallParser:
return final_normal_text, final_calls
def get_structure_tag(self) -> StructuralTagResponseFormat:
def get_structure_tag(self) -> LegacyStructuralTagResponseFormat:
"""
Generate a structural tag response format for all available tools.
@@ -151,7 +151,9 @@ class FunctionCallParser:
)
tool_trigger_set.add(info.trigger)
return StructuralTagResponseFormat(
# TODO(dark): move this into new structural tag format
# This requires all grammar backend support the new format
return LegacyStructuralTagResponseFormat(
type="structural_tag",
structures=tool_structures,
triggers=list(tool_trigger_set),
@@ -159,7 +161,7 @@ class FunctionCallParser:
def get_structure_constraint(
self, tool_choice: Union[ToolChoice, Literal["auto", "required"]]
) -> Optional[Tuple[str, Any]]:
) -> Optional[ToolCallConstraint]:
"""
Returns the appropriate structure constraint for tool calls based on the tool_choice.
The constraint is used to guide the model's output format.
@@ -178,8 +180,8 @@ class FunctionCallParser:
and tool_choice == "auto"
and any(tool.function.strict for tool in self.tools)
):
strict_tag = self.get_structure_tag()
return ("structural_tag", strict_tag)
tag = self.get_structure_tag()
return ("structural_tag", tag)
elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
json_schema = get_json_schema_constraint(self.tools, tool_choice)
return ("json_schema", json_schema)