diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py
index 7bf14bfc2..d00ba1428 100644
--- a/python/sglang/srt/constrained/xgrammar_backend.py
+++ b/python/sglang/srt/constrained/xgrammar_backend.py
@@ -13,6 +13,7 @@
 # ==============================================================================
 """Constrained decoding with xgrammar backend."""
 
+import json
 import logging
 from typing import List, Tuple
 
@@ -21,6 +22,7 @@ from xgrammar import (
     CompiledGrammar,
     GrammarCompiler,
     GrammarMatcher,
+    StructuralTagItem,
     TokenizerInfo,
     allocate_token_bitmask,
     apply_token_bitmask_inplace,
@@ -138,6 +140,29 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
             except RuntimeError as e:
                 logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
                 return None
+        elif key_type == "structural_tag":
+            # key_string is a JSON object with "structures" (items carrying
+            # "begin", "schema", "end") and "triggers".
+            try:
+                structural_tag = json.loads(key_string)
+                tags = [
+                    StructuralTagItem(
+                        begin=structure["begin"],
+                        schema=json.dumps(structure["schema"]),
+                        end=structure["end"],
+                    )
+                    for structure in structural_tag["structures"]
+                ]
+                ctx = self.grammar_compiler.compile_structural_tag(
+                    tags, structural_tag["triggers"]
+                )
+            except (RuntimeError, ValueError, KeyError, TypeError) as e:
+                # ValueError covers json.JSONDecodeError; KeyError/TypeError cover
+                # payloads missing the expected keys or with the wrong shape.
+                logging.warning(
+                    f"Skip invalid structural_tag: structural_tag={key_string}, {e=}"
+                )
+                return None
         else:
             raise ValueError(f"Invalid key_type: {key_type}")
 
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 6ef3f8ebe..ef4d29287 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -710,6 +710,7 @@ class Scheduler:
             req.sampling_params.json_schema is not None
             or req.sampling_params.regex is not None
             or req.sampling_params.ebnf is not None
+            or req.sampling_params.structural_tag is not None
         ):
             assert self.grammar_backend is not None
             if req.sampling_params.json_schema is not None:
@@ -718,6 +719,8 @@ class Scheduler:
                 key = ("regex", req.sampling_params.regex)
             elif req.sampling_params.ebnf is not None:
                 key = ("ebnf", req.sampling_params.ebnf)
+            elif req.sampling_params.structural_tag is not None:
+                key = ("structural_tag", req.sampling_params.structural_tag)
 
             req.grammar = self.grammar_backend.get_cached_value(key)
             if not req.grammar:
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index 0556f852a..7c385d40b 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -994,10 +994,19 @@ def v1_chat_generate_request(
             "ignore_eos": request.ignore_eos,
             "skip_special_tokens": request.skip_special_tokens,
         }
+
         if request.response_format and request.response_format.type == "json_schema":
             sampling_params["json_schema"] = convert_json_schema_to_str(
                 request.response_format.json_schema.schema_
             )
+        elif (
+            request.response_format and request.response_format.type == "structural_tag"
+        ):
+            # by_alias=True so each structure is dumped with the "schema" key
+            # (not "schema_"), which the xgrammar backend reads.
+            sampling_params["structural_tag"] = convert_json_schema_to_str(
+                request.response_format.model_dump(by_alias=True)
+            )
         sampling_params_list.append(sampling_params)
 
         image_data_list.append(image_data)
diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
index 95b34527e..5f1ba431a 100644
--- a/python/sglang/srt/openai_api/protocol.py
+++ b/python/sglang/srt/openai_api/protocol.py
@@ -258,6 +258,22 @@ class ResponseFormat(BaseModel):
     json_schema: Optional[JsonSchemaResponseFormat] = None
 
 
+class StructuresResponseFormat(BaseModel):
+    """One structural-tag structure: `begin`, optional `schema`, and `end`."""
+
+    begin: str
+    schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
+    end: str
+
+
+class StructuralTagResponseFormat(BaseModel):
+    """Response format of type "structural_tag": structures plus triggers."""
+
+    type: Literal["structural_tag"]
+    structures: List[StructuresResponseFormat]
+    triggers: List[str]
+
+
 class Function(BaseModel):
     """Function descriptions."""
 
@@ -298,7 +314,7 @@ class ChatCompletionRequest(BaseModel):
     max_tokens: Optional[int] = None
     n: int = 1
     presence_penalty: float = 0.0
-    response_format: Optional[ResponseFormat] = None
+    response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = None
     stream: bool = False
diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py
index d82a0f282..a478be2ce 100644
--- a/python/sglang/srt/sampling/sampling_params.py
+++ b/python/sglang/srt/sampling/sampling_params.py
@@ -45,6 +45,7 @@ class SamplingParams:
         json_schema: Optional[str] = None,
         regex: Optional[str] = None,
         ebnf: Optional[str] = None,
+        structural_tag: Optional[str] = None,
         no_stop_trim: bool = False,
         ignore_eos: bool = False,
         skip_special_tokens: bool = True,
@@ -72,6 +73,7 @@ class SamplingParams:
         self.n = n
         self.json_schema = json_schema
         self.ebnf = ebnf
+        self.structural_tag = structural_tag
         self.no_stop_trim = no_stop_trim
         self.return_hidden_states = return_hidden_states
         self.custom_params = custom_params