[Auto Sync] Update base_grammar_backend.py, llguidance_back... (20250911) (#10333)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -14,8 +14,9 @@
|
||||
"""The baseclass of a backend for grammar-guided constrained decoding."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Event
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
@@ -26,10 +27,22 @@ from sglang.srt.server_args import ServerArgs
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GrammarStats:
|
||||
compilation_time: Optional[float] = None
|
||||
schema_count: Optional[int] = None
|
||||
ebnf_size: Optional[int] = None
|
||||
is_cache_hit: bool = False
|
||||
is_grammar_aborted: bool = False
|
||||
tree_traversal_time: List[float] = field(default_factory=list)
|
||||
|
||||
|
||||
class BaseGrammarObject:
|
||||
|
||||
def __init__(self):
|
||||
self._finished = False
|
||||
self.grammar_stats = None
|
||||
self.current_token = None
|
||||
|
||||
def accept_token(self, token: int) -> None:
|
||||
"""
|
||||
@@ -137,19 +150,26 @@ class BaseGrammarBackend:
|
||||
return self._not_supported("structural_tag", key_string)
|
||||
|
||||
def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
|
||||
s = time.perf_counter()
|
||||
key_type, key_string = key
|
||||
if key_type == "json":
|
||||
return self.dispatch_json(key_string)
|
||||
grammar = self.dispatch_json(key_string)
|
||||
elif key_type == "regex":
|
||||
return self.dispatch_regex(key_string)
|
||||
grammar = self.dispatch_regex(key_string)
|
||||
elif key_type == "ebnf":
|
||||
return self.dispatch_ebnf(key_string)
|
||||
grammar = self.dispatch_ebnf(key_string)
|
||||
elif key_type == "structural_tag":
|
||||
return self.dispatch_structural_tag(key_string)
|
||||
grammar = self.dispatch_structural_tag(key_string)
|
||||
elif key_type == "structural_pattern":
|
||||
return self.dispatch_structural_pattern(key_string)
|
||||
grammar = self.dispatch_structural_pattern(key_string)
|
||||
elif key_type == "structural_pattern_v2":
|
||||
grammar = self.dispatch_structural_pattern_v2(key_string)
|
||||
else:
|
||||
return self.dispatch_fallback(key_type, key_string)
|
||||
grammar = self.dispatch_fallback(key_type, key_string)
|
||||
|
||||
if grammar is not None and grammar.grammar_stats is not None:
|
||||
grammar.grammar_stats.compilation_time = time.perf_counter() - s
|
||||
return grammar
|
||||
|
||||
def get_cached_or_future_value(
|
||||
self, key: Tuple[str, str]
|
||||
@@ -167,20 +187,36 @@ class BaseGrammarBackend:
|
||||
self.cache.clear()
|
||||
|
||||
|
||||
GRAMMAR_BACKEND_REGISTRY = {}
|
||||
|
||||
|
||||
def register_grammar_backend(name, init_func):
|
||||
GRAMMAR_BACKEND_REGISTRY[name] = init_func
|
||||
|
||||
|
||||
def create_grammar_backend(
|
||||
server_args: ServerArgs,
|
||||
tokenizer,
|
||||
vocab_size: int,
|
||||
eos_token_ids: Optional[set] = None,
|
||||
) -> Optional[BaseGrammarBackend]:
|
||||
if server_args.grammar_backend == "outlines":
|
||||
name = server_args.grammar_backend
|
||||
|
||||
# Custom grammar backend has the highest priority
|
||||
if name in GRAMMAR_BACKEND_REGISTRY:
|
||||
return GRAMMAR_BACKEND_REGISTRY[name](
|
||||
server_args, tokenizer, vocab_size, eos_token_ids
|
||||
)
|
||||
|
||||
# Default grammar backends
|
||||
if name == "outlines":
|
||||
from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend
|
||||
|
||||
grammar_backend = OutlinesGrammarBackend(
|
||||
tokenizer,
|
||||
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
|
||||
)
|
||||
elif server_args.grammar_backend == "xgrammar":
|
||||
elif name == "xgrammar":
|
||||
from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend
|
||||
|
||||
# Convert Set[int] to List[int] if needed
|
||||
@@ -189,17 +225,17 @@ def create_grammar_backend(
|
||||
grammar_backend = XGrammarGrammarBackend(
|
||||
tokenizer, vocab_size=vocab_size, model_eos_token_ids=eos_list
|
||||
)
|
||||
elif server_args.grammar_backend == "llguidance":
|
||||
elif name == "llguidance":
|
||||
from sglang.srt.constrained.llguidance_backend import GuidanceBackend
|
||||
|
||||
grammar_backend = GuidanceBackend(
|
||||
tokenizer=tokenizer,
|
||||
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
|
||||
)
|
||||
elif server_args.grammar_backend == "none":
|
||||
elif name == "none":
|
||||
return None
|
||||
else:
|
||||
raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
|
||||
raise ValueError(f"Invalid grammar backend: {name}")
|
||||
|
||||
if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
|
||||
from sglang.srt.constrained.reasoner_grammar_backend import (
|
||||
|
||||
@@ -48,7 +48,6 @@ class GuidanceGrammar(BaseGrammarObject):
|
||||
self.serialized_grammar,
|
||||
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
|
||||
)
|
||||
self.finished = False
|
||||
self.bitmask = None
|
||||
|
||||
def accept_token(self, token: int):
|
||||
|
||||
@@ -49,7 +49,6 @@ class OutlinesGrammar(BaseGrammarObject):
|
||||
self.guide = guide
|
||||
self.jump_forward_map = jump_forward_map
|
||||
self.state = 0
|
||||
self.finished = False
|
||||
|
||||
def accept_token(self, token: int):
|
||||
self.state = self.guide.get_next_state(self.state, token)
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# ==============================================================================
|
||||
"""Constrained decoding with xgrammar backend."""
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional, Tuple, Union
|
||||
@@ -31,6 +32,7 @@ from sglang.srt.constrained.base_grammar_backend import (
|
||||
INVALID_GRAMMAR_OBJ,
|
||||
BaseGrammarBackend,
|
||||
BaseGrammarObject,
|
||||
GrammarStats,
|
||||
)
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
@@ -41,9 +43,9 @@ else:
|
||||
from sglang.srt.constrained.triton_ops.bitmask_ops import (
|
||||
apply_token_bitmask_inplace_triton,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MAX_ROLLBACK_TOKENS = 200
|
||||
|
||||
|
||||
@@ -56,17 +58,20 @@ class XGrammarGrammar(BaseGrammarObject):
|
||||
ctx: CompiledGrammar,
|
||||
override_stop_tokens: Optional[Union[List[int], int]],
|
||||
key_string: Optional[str] = None, # TODO (sk): for debugging, remove later
|
||||
grammar_stats: Optional[GrammarStats] = GrammarStats(),
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.matcher = matcher
|
||||
self.vocab_size = vocab_size
|
||||
self.ctx = ctx
|
||||
self.override_stop_tokens = override_stop_tokens
|
||||
self.finished = False
|
||||
self.accepted_tokens = []
|
||||
self.key_string = key_string
|
||||
self.grammar_stats = grammar_stats
|
||||
|
||||
def accept_token(self, token: int):
|
||||
if not self.is_terminated():
|
||||
self.current_token = token
|
||||
accepted = self.matcher.accept_token(token)
|
||||
if not accepted:
|
||||
# log for debugging
|
||||
@@ -120,6 +125,9 @@ class XGrammarGrammar(BaseGrammarObject):
|
||||
self.ctx,
|
||||
self.override_stop_tokens,
|
||||
self.key_string,
|
||||
dataclasses.replace(
|
||||
self.grammar_stats, is_cache_hit=True, tree_traversal_time=[]
|
||||
),
|
||||
)
|
||||
|
||||
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
|
||||
@@ -150,7 +158,7 @@ class XGrammarGrammar(BaseGrammarObject):
|
||||
assert self.matcher.accept_token(new_output_ids[i])
|
||||
|
||||
def __repr__(self):
|
||||
return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=})"
|
||||
return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=}, {self.current_token=})"
|
||||
|
||||
|
||||
class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
@@ -177,14 +185,21 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
self.vocab_size = vocab_size
|
||||
self.override_stop_tokens = override_stop_tokens
|
||||
|
||||
def _from_context(self, ctx: CompiledGrammar, key_string: str) -> XGrammarGrammar:
|
||||
def _from_context(
|
||||
self, ctx: CompiledGrammar, key_string: str, grammar_stats: GrammarStats
|
||||
) -> XGrammarGrammar:
|
||||
matcher = GrammarMatcher(
|
||||
ctx,
|
||||
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
|
||||
override_stop_tokens=self.override_stop_tokens,
|
||||
)
|
||||
return XGrammarGrammar(
|
||||
matcher, self.vocab_size, ctx, self.override_stop_tokens, key_string
|
||||
matcher,
|
||||
self.vocab_size,
|
||||
ctx,
|
||||
self.override_stop_tokens,
|
||||
key_string,
|
||||
grammar_stats,
|
||||
)
|
||||
|
||||
def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
|
||||
@@ -198,7 +213,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
except (RuntimeError, json.decoder.JSONDecodeError) as e:
|
||||
logging.error(f"Hit invalid json_schema: {key_string=}, {e=}")
|
||||
return INVALID_GRAMMAR_OBJ
|
||||
return self._from_context(ctx, key_string)
|
||||
return self._from_context(ctx, key_string, GrammarStats())
|
||||
|
||||
def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
|
||||
try:
|
||||
@@ -206,7 +221,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
except RuntimeError as e:
|
||||
logging.error(f"Hit invalid ebnf: {key_string=}, {e=}")
|
||||
return INVALID_GRAMMAR_OBJ
|
||||
return self._from_context(ctx, key_string)
|
||||
return self._from_context(ctx, key_string, GrammarStats())
|
||||
|
||||
def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
|
||||
try:
|
||||
@@ -214,7 +229,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
except RuntimeError as e:
|
||||
logging.error(f"Hit invalid regex: {key_string=}, {e=}")
|
||||
return INVALID_GRAMMAR_OBJ
|
||||
return self._from_context(ctx, key_string)
|
||||
return self._from_context(ctx, key_string, GrammarStats())
|
||||
|
||||
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
|
||||
try:
|
||||
@@ -233,7 +248,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
except (RuntimeError, json.decoder.JSONDecodeError) as e:
|
||||
logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
|
||||
return INVALID_GRAMMAR_OBJ
|
||||
return self._from_context(ctx, key_string)
|
||||
return self._from_context(ctx, key_string, GrammarStats())
|
||||
|
||||
def reset(self):
|
||||
self.grammar_compiler.clear_cache()
|
||||
|
||||
@@ -110,7 +110,10 @@ class ScheduleBatchDisaggregationDecodeMixin:
|
||||
if req.grammar is not None:
|
||||
# FIXME: this try-except block is for handling unexpected xgrammar issue.
|
||||
try:
|
||||
req.grammar.accept_token(req.output_ids[-1])
|
||||
# if it is not None, then the grammar is from a retracted request, and we should not
|
||||
# accept the token as it's already accepted
|
||||
if req.grammar.current_token is None:
|
||||
req.grammar.accept_token(req.output_ids[-1])
|
||||
except ValueError as e:
|
||||
# Grammar accept_token can raise ValueError if the token is not in the grammar.
|
||||
# This can happen if the grammar is not set correctly or the token is invalid.
|
||||
|
||||
Reference in New Issue
Block a user