From 2269cf1e2faf1c1501f5c48fd6e63c050d24cf13 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 12 Sep 2025 12:55:55 -0700 Subject: [PATCH] [Auto Sync] Update base_grammar_backend.py, llguidance_back... (20250911) (#10333) Co-authored-by: github-actions[bot] --- .../srt/constrained/base_grammar_backend.py | 60 +++++++++++++++---- .../srt/constrained/llguidance_backend.py | 1 - .../srt/constrained/outlines_backend.py | 1 - .../srt/constrained/xgrammar_backend.py | 35 +++++++---- .../decode_schedule_batch_mixin.py | 5 +- 5 files changed, 77 insertions(+), 25 deletions(-) diff --git a/python/sglang/srt/constrained/base_grammar_backend.py b/python/sglang/srt/constrained/base_grammar_backend.py index 4fe5d6c77..23bcd1bd3 100644 --- a/python/sglang/srt/constrained/base_grammar_backend.py +++ b/python/sglang/srt/constrained/base_grammar_backend.py @@ -14,8 +14,9 @@ """The baseclass of a backend for grammar-guided constrained decoding.""" import logging +import time from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass +from dataclasses import dataclass, field from threading import Event from typing import Dict, List, Optional, Tuple @@ -26,10 +27,22 @@ from sglang.srt.server_args import ServerArgs logger = logging.getLogger(__name__) +@dataclass +class GrammarStats: + compilation_time: Optional[float] = None + schema_count: Optional[int] = None + ebnf_size: Optional[int] = None + is_cache_hit: bool = False + is_grammar_aborted: bool = False + tree_traversal_time: List[float] = field(default_factory=list) + + class BaseGrammarObject: def __init__(self): self._finished = False + self.grammar_stats = None + self.current_token = None def accept_token(self, token: int) -> None: """ @@ -137,19 +150,26 @@ class BaseGrammarBackend: return self._not_supported("structural_tag", key_string) def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]: + s = time.perf_counter() key_type, key_string = key if key_type == "json": - return self.dispatch_json(key_string) + grammar = self.dispatch_json(key_string) elif key_type == "regex": - return self.dispatch_regex(key_string) + grammar = self.dispatch_regex(key_string) elif key_type == "ebnf": - return self.dispatch_ebnf(key_string) + grammar = self.dispatch_ebnf(key_string) elif key_type == "structural_tag": - return self.dispatch_structural_tag(key_string) + grammar = self.dispatch_structural_tag(key_string) elif key_type == "structural_pattern": - return self.dispatch_structural_pattern(key_string) + grammar = self.dispatch_structural_pattern(key_string) + elif key_type == "structural_pattern_v2": + grammar = self.dispatch_structural_pattern_v2(key_string) else: - return self.dispatch_fallback(key_type, key_string) + grammar = self.dispatch_fallback(key_type, key_string) + + if grammar is not None and grammar.grammar_stats is not None: + grammar.grammar_stats.compilation_time = time.perf_counter() - s + return grammar def get_cached_or_future_value( self, key: Tuple[str, str] @@ -167,20 +187,36 @@ class BaseGrammarBackend: self.cache.clear() +GRAMMAR_BACKEND_REGISTRY = {} + + +def register_grammar_backend(name, init_func): + GRAMMAR_BACKEND_REGISTRY[name] = init_func + + def create_grammar_backend( server_args: ServerArgs, tokenizer, vocab_size: int, eos_token_ids: Optional[set] = None, ) -> Optional[BaseGrammarBackend]: - if server_args.grammar_backend == "outlines": + name = server_args.grammar_backend + + # Custom grammar backend has the highest priority + if name in GRAMMAR_BACKEND_REGISTRY: + return GRAMMAR_BACKEND_REGISTRY[name]( + server_args, tokenizer, vocab_size, eos_token_ids + ) + + # Default grammar backends + if name == "outlines": from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend grammar_backend = OutlinesGrammarBackend( tokenizer, whitespace_pattern=server_args.constrained_json_whitespace_pattern, ) - elif server_args.grammar_backend == "xgrammar": + elif name == "xgrammar": from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend # Convert Set[int] to List[int] if needed @@ -189,17 +225,17 @@ def create_grammar_backend( grammar_backend = XGrammarGrammarBackend( tokenizer, vocab_size=vocab_size, model_eos_token_ids=eos_list ) - elif server_args.grammar_backend == "llguidance": + elif name == "llguidance": from sglang.srt.constrained.llguidance_backend import GuidanceBackend grammar_backend = GuidanceBackend( tokenizer=tokenizer, whitespace_pattern=server_args.constrained_json_whitespace_pattern, ) - elif server_args.grammar_backend == "none": + elif name == "none": return None else: - raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}") + raise ValueError(f"Invalid grammar backend: {name}") if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"): from sglang.srt.constrained.reasoner_grammar_backend import ( diff --git a/python/sglang/srt/constrained/llguidance_backend.py b/python/sglang/srt/constrained/llguidance_backend.py index 2acbf2c51..5e29c2524 100644 --- a/python/sglang/srt/constrained/llguidance_backend.py +++ b/python/sglang/srt/constrained/llguidance_backend.py @@ -48,7 +48,6 @@ class GuidanceGrammar(BaseGrammarObject): self.serialized_grammar, log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")), ) - self.finished = False self.bitmask = None def accept_token(self, token: int): diff --git a/python/sglang/srt/constrained/outlines_backend.py b/python/sglang/srt/constrained/outlines_backend.py index 5302fadaa..b54e34b3d 100644 --- a/python/sglang/srt/constrained/outlines_backend.py +++ b/python/sglang/srt/constrained/outlines_backend.py @@ -49,7 +49,6 @@ class OutlinesGrammar(BaseGrammarObject): self.guide = guide self.jump_forward_map = jump_forward_map self.state = 0 - self.finished = False def accept_token(self, token: int): self.state = self.guide.get_next_state(self.state, token) diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py index 7b101df4f..8ff55d261 100644 --- a/python/sglang/srt/constrained/xgrammar_backend.py +++ b/python/sglang/srt/constrained/xgrammar_backend.py @@ -13,6 +13,7 @@ # ============================================================================== """Constrained decoding with xgrammar backend.""" +import dataclasses import json import logging from typing import List, Optional, Tuple, Union @@ -31,6 +32,7 @@ from sglang.srt.constrained.base_grammar_backend import ( INVALID_GRAMMAR_OBJ, BaseGrammarBackend, BaseGrammarObject, + GrammarStats, ) from sglang.srt.utils import is_hip @@ -41,9 +43,9 @@ else: from sglang.srt.constrained.triton_ops.bitmask_ops import ( apply_token_bitmask_inplace_triton, ) + + logger = logging.getLogger(__name__) - - MAX_ROLLBACK_TOKENS = 200 @@ -56,17 +58,20 @@ class XGrammarGrammar(BaseGrammarObject): ctx: CompiledGrammar, override_stop_tokens: Optional[Union[List[int], int]], key_string: Optional[str] = None, # TODO (sk): for debugging, remove later + grammar_stats: Optional[GrammarStats] = GrammarStats(), ) -> None: + super().__init__() self.matcher = matcher self.vocab_size = vocab_size self.ctx = ctx self.override_stop_tokens = override_stop_tokens - self.finished = False self.accepted_tokens = [] self.key_string = key_string + self.grammar_stats = grammar_stats def accept_token(self, token: int): if not self.is_terminated(): + self.current_token = token accepted = self.matcher.accept_token(token) if not accepted: # log for debugging @@ -120,6 +125,9 @@ class XGrammarGrammar(BaseGrammarObject): self.ctx, self.override_stop_tokens, self.key_string, + dataclasses.replace( + self.grammar_stats, is_cache_hit=True, tree_traversal_time=[] + ), ) def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]: @@ -150,7 +158,7 @@ class XGrammarGrammar(BaseGrammarObject): assert self.matcher.accept_token(new_output_ids[i]) def __repr__(self): - return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=})" + return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=}, {self.current_token=})" class XGrammarGrammarBackend(BaseGrammarBackend): @@ -177,14 +185,21 @@ class XGrammarGrammarBackend(BaseGrammarBackend): self.vocab_size = vocab_size self.override_stop_tokens = override_stop_tokens - def _from_context(self, ctx: CompiledGrammar, key_string: str) -> XGrammarGrammar: + def _from_context( + self, ctx: CompiledGrammar, key_string: str, grammar_stats: GrammarStats + ) -> XGrammarGrammar: matcher = GrammarMatcher( ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, override_stop_tokens=self.override_stop_tokens, ) return XGrammarGrammar( - matcher, self.vocab_size, ctx, self.override_stop_tokens, key_string + matcher, + self.vocab_size, + ctx, + self.override_stop_tokens, + key_string, + grammar_stats, ) def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]: @@ -198,7 +213,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend): except (RuntimeError, json.decoder.JSONDecodeError) as e: logging.error(f"Hit invalid json_schema: {key_string=}, {e=}") return INVALID_GRAMMAR_OBJ - return self._from_context(ctx, key_string) + return self._from_context(ctx, key_string, GrammarStats()) def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]: try: @@ -206,7 +221,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend): except RuntimeError as e: logging.error(f"Hit invalid ebnf: {key_string=}, {e=}") return INVALID_GRAMMAR_OBJ - return self._from_context(ctx, key_string) + return self._from_context(ctx, key_string, GrammarStats()) def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]: try: @@ -214,7 +229,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend): except RuntimeError as e: logging.error(f"Hit invalid regex: {key_string=}, {e=}") return INVALID_GRAMMAR_OBJ - return self._from_context(ctx, key_string) + return self._from_context(ctx, key_string, GrammarStats()) def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]: try: @@ -233,7 +248,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend): except (RuntimeError, json.decoder.JSONDecodeError) as e: logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}") return INVALID_GRAMMAR_OBJ - return self._from_context(ctx, key_string) + return self._from_context(ctx, key_string, GrammarStats()) def reset(self): self.grammar_compiler.clear_cache() diff --git a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py index c1cb17c04..be0383eec 100644 --- a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py +++ b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py @@ -110,7 +110,10 @@ class ScheduleBatchDisaggregationDecodeMixin: if req.grammar is not None: # FIXME: this try-except block is for handling unexpected xgrammar issue. try: - req.grammar.accept_token(req.output_ids[-1]) + # if it is not None, then the grammar is from a retracted request, and we should not + # accept the token as it's already accepted + if req.grammar.current_token is None: + req.grammar.accept_token(req.output_ids[-1]) except ValueError as e: # Grammar accept_token can raise ValueError if the token is not in the grammar. # This can happen if the grammar is not set correctly or the token is invalid.