[Auto Sync] Update base_grammar_backend.py, llguidance_back... (20250911) (#10333)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -14,8 +14,9 @@
|
|||||||
"""The baseclass of a backend for grammar-guided constrained decoding."""
|
"""The baseclass of a backend for grammar-guided constrained decoding."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from threading import Event
|
from threading import Event
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
@@ -26,10 +27,22 @@ from sglang.srt.server_args import ServerArgs
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GrammarStats:
|
||||||
|
compilation_time: Optional[float] = None
|
||||||
|
schema_count: Optional[int] = None
|
||||||
|
ebnf_size: Optional[int] = None
|
||||||
|
is_cache_hit: bool = False
|
||||||
|
is_grammar_aborted: bool = False
|
||||||
|
tree_traversal_time: List[float] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class BaseGrammarObject:
|
class BaseGrammarObject:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._finished = False
|
self._finished = False
|
||||||
|
self.grammar_stats = None
|
||||||
|
self.current_token = None
|
||||||
|
|
||||||
def accept_token(self, token: int) -> None:
|
def accept_token(self, token: int) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -137,19 +150,26 @@ class BaseGrammarBackend:
|
|||||||
return self._not_supported("structural_tag", key_string)
|
return self._not_supported("structural_tag", key_string)
|
||||||
|
|
||||||
def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
|
def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
|
||||||
|
s = time.perf_counter()
|
||||||
key_type, key_string = key
|
key_type, key_string = key
|
||||||
if key_type == "json":
|
if key_type == "json":
|
||||||
return self.dispatch_json(key_string)
|
grammar = self.dispatch_json(key_string)
|
||||||
elif key_type == "regex":
|
elif key_type == "regex":
|
||||||
return self.dispatch_regex(key_string)
|
grammar = self.dispatch_regex(key_string)
|
||||||
elif key_type == "ebnf":
|
elif key_type == "ebnf":
|
||||||
return self.dispatch_ebnf(key_string)
|
grammar = self.dispatch_ebnf(key_string)
|
||||||
elif key_type == "structural_tag":
|
elif key_type == "structural_tag":
|
||||||
return self.dispatch_structural_tag(key_string)
|
grammar = self.dispatch_structural_tag(key_string)
|
||||||
elif key_type == "structural_pattern":
|
elif key_type == "structural_pattern":
|
||||||
return self.dispatch_structural_pattern(key_string)
|
grammar = self.dispatch_structural_pattern(key_string)
|
||||||
|
elif key_type == "structural_pattern_v2":
|
||||||
|
grammar = self.dispatch_structural_pattern_v2(key_string)
|
||||||
else:
|
else:
|
||||||
return self.dispatch_fallback(key_type, key_string)
|
grammar = self.dispatch_fallback(key_type, key_string)
|
||||||
|
|
||||||
|
if grammar is not None and grammar.grammar_stats is not None:
|
||||||
|
grammar.grammar_stats.compilation_time = time.perf_counter() - s
|
||||||
|
return grammar
|
||||||
|
|
||||||
def get_cached_or_future_value(
|
def get_cached_or_future_value(
|
||||||
self, key: Tuple[str, str]
|
self, key: Tuple[str, str]
|
||||||
@@ -167,20 +187,36 @@ class BaseGrammarBackend:
|
|||||||
self.cache.clear()
|
self.cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
GRAMMAR_BACKEND_REGISTRY = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_grammar_backend(name, init_func):
|
||||||
|
GRAMMAR_BACKEND_REGISTRY[name] = init_func
|
||||||
|
|
||||||
|
|
||||||
def create_grammar_backend(
|
def create_grammar_backend(
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
eos_token_ids: Optional[set] = None,
|
eos_token_ids: Optional[set] = None,
|
||||||
) -> Optional[BaseGrammarBackend]:
|
) -> Optional[BaseGrammarBackend]:
|
||||||
if server_args.grammar_backend == "outlines":
|
name = server_args.grammar_backend
|
||||||
|
|
||||||
|
# Custom grammar backend has the highest priority
|
||||||
|
if name in GRAMMAR_BACKEND_REGISTRY:
|
||||||
|
return GRAMMAR_BACKEND_REGISTRY[name](
|
||||||
|
server_args, tokenizer, vocab_size, eos_token_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
# Default grammar backends
|
||||||
|
if name == "outlines":
|
||||||
from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend
|
from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend
|
||||||
|
|
||||||
grammar_backend = OutlinesGrammarBackend(
|
grammar_backend = OutlinesGrammarBackend(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
|
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
|
||||||
)
|
)
|
||||||
elif server_args.grammar_backend == "xgrammar":
|
elif name == "xgrammar":
|
||||||
from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend
|
from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend
|
||||||
|
|
||||||
# Convert Set[int] to List[int] if needed
|
# Convert Set[int] to List[int] if needed
|
||||||
@@ -189,17 +225,17 @@ def create_grammar_backend(
|
|||||||
grammar_backend = XGrammarGrammarBackend(
|
grammar_backend = XGrammarGrammarBackend(
|
||||||
tokenizer, vocab_size=vocab_size, model_eos_token_ids=eos_list
|
tokenizer, vocab_size=vocab_size, model_eos_token_ids=eos_list
|
||||||
)
|
)
|
||||||
elif server_args.grammar_backend == "llguidance":
|
elif name == "llguidance":
|
||||||
from sglang.srt.constrained.llguidance_backend import GuidanceBackend
|
from sglang.srt.constrained.llguidance_backend import GuidanceBackend
|
||||||
|
|
||||||
grammar_backend = GuidanceBackend(
|
grammar_backend = GuidanceBackend(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
|
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
|
||||||
)
|
)
|
||||||
elif server_args.grammar_backend == "none":
|
elif name == "none":
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
|
raise ValueError(f"Invalid grammar backend: {name}")
|
||||||
|
|
||||||
if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
|
if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
|
||||||
from sglang.srt.constrained.reasoner_grammar_backend import (
|
from sglang.srt.constrained.reasoner_grammar_backend import (
|
||||||
|
|||||||
@@ -48,7 +48,6 @@ class GuidanceGrammar(BaseGrammarObject):
|
|||||||
self.serialized_grammar,
|
self.serialized_grammar,
|
||||||
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
|
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
|
||||||
)
|
)
|
||||||
self.finished = False
|
|
||||||
self.bitmask = None
|
self.bitmask = None
|
||||||
|
|
||||||
def accept_token(self, token: int):
|
def accept_token(self, token: int):
|
||||||
|
|||||||
@@ -49,7 +49,6 @@ class OutlinesGrammar(BaseGrammarObject):
|
|||||||
self.guide = guide
|
self.guide = guide
|
||||||
self.jump_forward_map = jump_forward_map
|
self.jump_forward_map = jump_forward_map
|
||||||
self.state = 0
|
self.state = 0
|
||||||
self.finished = False
|
|
||||||
|
|
||||||
def accept_token(self, token: int):
|
def accept_token(self, token: int):
|
||||||
self.state = self.guide.get_next_state(self.state, token)
|
self.state = self.guide.get_next_state(self.state, token)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Constrained decoding with xgrammar backend."""
|
"""Constrained decoding with xgrammar backend."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
@@ -31,6 +32,7 @@ from sglang.srt.constrained.base_grammar_backend import (
|
|||||||
INVALID_GRAMMAR_OBJ,
|
INVALID_GRAMMAR_OBJ,
|
||||||
BaseGrammarBackend,
|
BaseGrammarBackend,
|
||||||
BaseGrammarObject,
|
BaseGrammarObject,
|
||||||
|
GrammarStats,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_hip
|
||||||
|
|
||||||
@@ -41,9 +43,9 @@ else:
|
|||||||
from sglang.srt.constrained.triton_ops.bitmask_ops import (
|
from sglang.srt.constrained.triton_ops.bitmask_ops import (
|
||||||
apply_token_bitmask_inplace_triton,
|
apply_token_bitmask_inplace_triton,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
MAX_ROLLBACK_TOKENS = 200
|
MAX_ROLLBACK_TOKENS = 200
|
||||||
|
|
||||||
|
|
||||||
@@ -56,17 +58,20 @@ class XGrammarGrammar(BaseGrammarObject):
|
|||||||
ctx: CompiledGrammar,
|
ctx: CompiledGrammar,
|
||||||
override_stop_tokens: Optional[Union[List[int], int]],
|
override_stop_tokens: Optional[Union[List[int], int]],
|
||||||
key_string: Optional[str] = None, # TODO (sk): for debugging, remove later
|
key_string: Optional[str] = None, # TODO (sk): for debugging, remove later
|
||||||
|
grammar_stats: Optional[GrammarStats] = GrammarStats(),
|
||||||
) -> None:
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
self.matcher = matcher
|
self.matcher = matcher
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
self.override_stop_tokens = override_stop_tokens
|
self.override_stop_tokens = override_stop_tokens
|
||||||
self.finished = False
|
|
||||||
self.accepted_tokens = []
|
self.accepted_tokens = []
|
||||||
self.key_string = key_string
|
self.key_string = key_string
|
||||||
|
self.grammar_stats = grammar_stats
|
||||||
|
|
||||||
def accept_token(self, token: int):
|
def accept_token(self, token: int):
|
||||||
if not self.is_terminated():
|
if not self.is_terminated():
|
||||||
|
self.current_token = token
|
||||||
accepted = self.matcher.accept_token(token)
|
accepted = self.matcher.accept_token(token)
|
||||||
if not accepted:
|
if not accepted:
|
||||||
# log for debugging
|
# log for debugging
|
||||||
@@ -120,6 +125,9 @@ class XGrammarGrammar(BaseGrammarObject):
|
|||||||
self.ctx,
|
self.ctx,
|
||||||
self.override_stop_tokens,
|
self.override_stop_tokens,
|
||||||
self.key_string,
|
self.key_string,
|
||||||
|
dataclasses.replace(
|
||||||
|
self.grammar_stats, is_cache_hit=True, tree_traversal_time=[]
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
|
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
|
||||||
@@ -150,7 +158,7 @@ class XGrammarGrammar(BaseGrammarObject):
|
|||||||
assert self.matcher.accept_token(new_output_ids[i])
|
assert self.matcher.accept_token(new_output_ids[i])
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=})"
|
return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=}, {self.current_token=})"
|
||||||
|
|
||||||
|
|
||||||
class XGrammarGrammarBackend(BaseGrammarBackend):
|
class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||||
@@ -177,14 +185,21 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.override_stop_tokens = override_stop_tokens
|
self.override_stop_tokens = override_stop_tokens
|
||||||
|
|
||||||
def _from_context(self, ctx: CompiledGrammar, key_string: str) -> XGrammarGrammar:
|
def _from_context(
|
||||||
|
self, ctx: CompiledGrammar, key_string: str, grammar_stats: GrammarStats
|
||||||
|
) -> XGrammarGrammar:
|
||||||
matcher = GrammarMatcher(
|
matcher = GrammarMatcher(
|
||||||
ctx,
|
ctx,
|
||||||
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
|
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
|
||||||
override_stop_tokens=self.override_stop_tokens,
|
override_stop_tokens=self.override_stop_tokens,
|
||||||
)
|
)
|
||||||
return XGrammarGrammar(
|
return XGrammarGrammar(
|
||||||
matcher, self.vocab_size, ctx, self.override_stop_tokens, key_string
|
matcher,
|
||||||
|
self.vocab_size,
|
||||||
|
ctx,
|
||||||
|
self.override_stop_tokens,
|
||||||
|
key_string,
|
||||||
|
grammar_stats,
|
||||||
)
|
)
|
||||||
|
|
||||||
def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
|
def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
|
||||||
@@ -198,7 +213,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|||||||
except (RuntimeError, json.decoder.JSONDecodeError) as e:
|
except (RuntimeError, json.decoder.JSONDecodeError) as e:
|
||||||
logging.error(f"Hit invalid json_schema: {key_string=}, {e=}")
|
logging.error(f"Hit invalid json_schema: {key_string=}, {e=}")
|
||||||
return INVALID_GRAMMAR_OBJ
|
return INVALID_GRAMMAR_OBJ
|
||||||
return self._from_context(ctx, key_string)
|
return self._from_context(ctx, key_string, GrammarStats())
|
||||||
|
|
||||||
def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
|
def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
|
||||||
try:
|
try:
|
||||||
@@ -206,7 +221,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logging.error(f"Hit invalid ebnf: {key_string=}, {e=}")
|
logging.error(f"Hit invalid ebnf: {key_string=}, {e=}")
|
||||||
return INVALID_GRAMMAR_OBJ
|
return INVALID_GRAMMAR_OBJ
|
||||||
return self._from_context(ctx, key_string)
|
return self._from_context(ctx, key_string, GrammarStats())
|
||||||
|
|
||||||
def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
|
def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
|
||||||
try:
|
try:
|
||||||
@@ -214,7 +229,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logging.error(f"Hit invalid regex: {key_string=}, {e=}")
|
logging.error(f"Hit invalid regex: {key_string=}, {e=}")
|
||||||
return INVALID_GRAMMAR_OBJ
|
return INVALID_GRAMMAR_OBJ
|
||||||
return self._from_context(ctx, key_string)
|
return self._from_context(ctx, key_string, GrammarStats())
|
||||||
|
|
||||||
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
|
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
|
||||||
try:
|
try:
|
||||||
@@ -233,7 +248,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|||||||
except (RuntimeError, json.decoder.JSONDecodeError) as e:
|
except (RuntimeError, json.decoder.JSONDecodeError) as e:
|
||||||
logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
|
logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
|
||||||
return INVALID_GRAMMAR_OBJ
|
return INVALID_GRAMMAR_OBJ
|
||||||
return self._from_context(ctx, key_string)
|
return self._from_context(ctx, key_string, GrammarStats())
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.grammar_compiler.clear_cache()
|
self.grammar_compiler.clear_cache()
|
||||||
|
|||||||
@@ -110,7 +110,10 @@ class ScheduleBatchDisaggregationDecodeMixin:
|
|||||||
if req.grammar is not None:
|
if req.grammar is not None:
|
||||||
# FIXME: this try-except block is for handling unexpected xgrammar issue.
|
# FIXME: this try-except block is for handling unexpected xgrammar issue.
|
||||||
try:
|
try:
|
||||||
req.grammar.accept_token(req.output_ids[-1])
|
# if it is not None, then the grammar is from a retracted request, and we should not
|
||||||
|
# accept the token as it's already accepted
|
||||||
|
if req.grammar.current_token is None:
|
||||||
|
req.grammar.accept_token(req.output_ids[-1])
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
# Grammar accept_token can raise ValueError if the token is not in the grammar.
|
# Grammar accept_token can raise ValueError if the token is not in the grammar.
|
||||||
# This can happen if the grammar is not set correctly or the token is invalid.
|
# This can happen if the grammar is not set correctly or the token is invalid.
|
||||||
|
|||||||
Reference in New Issue
Block a user