Files
sglang/python/sglang/srt/constrained/base_grammar_backend.py
Lianmin Zheng 5dddb331c4 [Auto Sync] Update base_grammar_backend.py, xgrammar_backen... (20250930) (#11115)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sehoon Kim <sehoon@x.ai>
2025-09-30 21:50:43 -07:00

251 lines
7.9 KiB
Python

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The baseclass of a backend for grammar-guided constrained decoding."""
import logging
import time
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass, field
from threading import Event
from typing import Dict, List, Optional, Tuple, Union

import torch

from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)
@dataclass
class GrammarStats:
    """Per-grammar statistics collected during compilation and decoding."""

    # Wall-clock seconds spent compiling the grammar; set by
    # BaseGrammarBackend._init_value_dispatch via time.perf_counter().
    compilation_time: Optional[float] = None
    # NOTE(review): populated outside this file — presumably the number of
    # schemas involved in the grammar; confirm at the producer.
    schema_count: Optional[int] = None
    # NOTE(review): size of the EBNF representation (unit not shown here —
    # presumably characters); confirm at the producer.
    ebnf_size: Optional[int] = None
    # Whether the grammar was served from a cache rather than recompiled.
    is_cache_hit: bool = False
    # Whether decoding under this grammar was aborted.
    is_grammar_aborted: bool = False
    # Accumulated per-step traversal times; appended to outside this file.
    tree_traversal_time: List[float] = field(default_factory=list)
    # Which dispatch path produced the grammar (e.g. "json", "regex", "ebnf").
    dispatch_type: Optional[str] = None
class BaseGrammarObject:
    """Abstract interface of a compiled grammar object.

    Concrete backends implement token acceptance, vocab-mask handling, and
    (optionally) the jump-forward optimization. The base class itself only
    tracks a `finished` flag and placeholder stats.
    """

    def __init__(self):
        self._finished = False
        self.grammar_stats = None
        self.current_token = None

    @property
    def finished(self):
        """Whether decoding under this grammar object has finished."""
        return self._finished

    @finished.setter
    def finished(self, finished):
        self._finished = finished

    def is_terminated(self):
        """Return True once the grammar reaches a terminal state (base: never)."""
        return False

    def copy(self) -> "BaseGrammarObject":
        """Return a per-request copy; the base object returns itself unchanged."""
        return self

    def accept_token(self, token: int) -> None:
        """
        Accept a token in the grammar.
        """
        raise NotImplementedError

    def rollback(self, k: int):
        """Undo the effect of the last `k` accepted tokens."""
        raise NotImplementedError

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        """Allocate a vocabulary mask tensor for a batch of requests."""
        raise NotImplementedError

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        """Fill row `idx` of `vocab_mask` with the currently allowed tokens."""
        raise NotImplementedError

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        """Move the vocabulary mask onto `device` before it is applied."""
        raise NotImplementedError

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        """Apply `vocab_mask` to `logits` (no return value, mutates in place)."""
        raise NotImplementedError

    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        """
        Try to jump forward in the grammar.

        Returns:
            A jump forward helper which may be used in `jump_forward_str_state`.
            None if the jump forward is not possible.
        """
        raise NotImplementedError

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        """
        Jump forward for the grammar.

        Returns:
            A tuple of the jump forward string and the next state of the grammar
            (which can be used in `jump_and_retokenize` if needed).
        """
        raise NotImplementedError

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ) -> None:
        """
        Jump forward occurs, and update the grammar state if needed.
        """
        raise NotImplementedError
# Shared sentinel instance; the name suggests callers compare against it to
# flag a grammar that failed to compile — usage lives outside this file.
INVALID_GRAMMAR_OBJ = BaseGrammarObject()
@dataclass
class CacheEntry:
    """A cached grammar `value` paired with a threading `Event`."""

    # The compiled grammar object.
    value: BaseGrammarObject
    # NOTE(review): presumably signals when `value` is ready. This dataclass is
    # not constructed anywhere in this file (BaseGrammarBackend.set_cache stores
    # BaseGrammarObject values directly) — confirm usage at call sites.
    event: Event
class BaseGrammarBackend:
    """Compiles grammar objects on a thread pool and caches them by key.

    Keys are `(key_type, key_string)` tuples, e.g. `("json", <schema>)`.
    Subclasses override the `dispatch_*` hooks to build backend-specific
    grammar objects.
    """

    def __init__(self):
        self.executor = ThreadPoolExecutor()
        # Maps (key_type, key_string) -> compiled grammar object.
        # Fixed annotation: set_cache stores BaseGrammarObject values directly
        # and get_cached_or_future_value calls .copy() on them, so the value
        # type is BaseGrammarObject (CacheEntry was never stored here).
        self.cache: Dict[Tuple[str, str], BaseGrammarObject] = {}

    def _not_supported(self, key_type: str, key_string: str) -> None:
        """Log and skip a constraint type this backend does not implement."""
        logger.warning(f"Skip unsupported {key_type=}, {key_string=}")

    def dispatch_fallback(
        self, key_type: str, key_string: str
    ) -> Optional[BaseGrammarObject]:
        """
        This function should not be reached in any case.
        """
        raise ValueError(f"Invalid key_type: {key_type}={key_string}")

    def dispatch_json(self, key_string: str) -> Optional[BaseGrammarObject]:
        return self._not_supported("json", key_string)

    def dispatch_regex(self, key_string: str) -> Optional[BaseGrammarObject]:
        return self._not_supported("regex", key_string)

    def dispatch_ebnf(self, key_string: str) -> Optional[BaseGrammarObject]:
        return self._not_supported("ebnf", key_string)

    def dispatch_structural_tag(self, key_string: str) -> Optional[BaseGrammarObject]:
        return self._not_supported("structural_tag", key_string)

    def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
        """Compile the grammar for `key`, recording the compilation time.

        Returns None when the key type is unsupported by this backend.
        """
        s = time.perf_counter()
        key_type, key_string = key
        if key_type == "json":
            grammar = self.dispatch_json(key_string)
        elif key_type == "regex":
            grammar = self.dispatch_regex(key_string)
        elif key_type == "ebnf":
            grammar = self.dispatch_ebnf(key_string)
        elif key_type == "structural_tag":
            grammar = self.dispatch_structural_tag(key_string)
        elif key_type == "structural_pattern":
            # NOTE(review): dispatch_structural_pattern / _v2 are not defined on
            # this base class — they appear to be subclass hooks; these branches
            # raise AttributeError on a backend that lacks them. Confirm all
            # backends receiving these key types define the hooks.
            grammar = self.dispatch_structural_pattern(key_string)
        elif key_type == "structural_pattern_v2":
            grammar = self.dispatch_structural_pattern_v2(key_string)
        else:
            grammar = self.dispatch_fallback(key_type, key_string)
        if grammar is not None and grammar.grammar_stats is not None:
            grammar.grammar_stats.compilation_time = time.perf_counter() - s
        return grammar

    def get_cached_or_future_value(
        self, key: Tuple[str, str]
    ) -> Tuple[Union[BaseGrammarObject, Future], bool]:
        """Return (grammar, True) on a cache hit, else (Future, False).

        On a miss, compilation is submitted to the thread pool and the caller
        receives the pending Future; fixed annotation — this method always
        returns a 2-tuple, never a bare grammar object.
        """
        value = self.cache.get(key)
        # Explicit None check: a present-but-falsy entry must still count as a hit.
        if value is not None:
            return value.copy(), True
        return self.executor.submit(self._init_value_dispatch, key), False

    def set_cache(self, key: Tuple[str, str], value: BaseGrammarObject):
        """Store a compiled grammar object under `key`."""
        self.cache[key] = value

    def reset(self):
        """Drop all cached grammar objects."""
        self.cache.clear()
# Global registry mapping a backend name to its factory; custom backends
# registered here take priority in create_grammar_backend.
GRAMMAR_BACKEND_REGISTRY = {}


def register_grammar_backend(name, init_func):
    """Register `init_func` as the factory for the grammar backend `name`."""
    GRAMMAR_BACKEND_REGISTRY[name] = init_func
def create_grammar_backend(
    server_args: ServerArgs,
    tokenizer,
    vocab_size: int,
    eos_token_ids: Optional[set] = None,
) -> Optional[BaseGrammarBackend]:
    """Instantiate the grammar backend named by `server_args.grammar_backend`.

    A factory registered via `register_grammar_backend` takes priority over the
    built-in backends. Returns None for the backend name "none". When a
    reasoning parser is configured and the tokenizer exposes `think_end_id`,
    the backend is wrapped in a ReasonerGrammarBackend.

    Raises:
        ValueError: if the backend name is unknown.
    """
    backend_name = server_args.grammar_backend

    # Custom grammar backend has the highest priority
    if backend_name in GRAMMAR_BACKEND_REGISTRY:
        factory = GRAMMAR_BACKEND_REGISTRY[backend_name]
        return factory(server_args, tokenizer, vocab_size, eos_token_ids)

    # Built-in backends; imports stay inside the branches so that only the
    # selected backend's dependencies are loaded.
    if backend_name == "outlines":
        from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend

        backend = OutlinesGrammarBackend(
            tokenizer,
            whitespace_pattern=server_args.constrained_json_whitespace_pattern,
        )
    elif backend_name == "xgrammar":
        from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend

        # XGrammar takes a List[int] (or None), not a set, of EOS token ids.
        eos_list = list(eos_token_ids) if eos_token_ids else None
        backend = XGrammarGrammarBackend(
            tokenizer, vocab_size=vocab_size, model_eos_token_ids=eos_list
        )
    elif backend_name == "llguidance":
        from sglang.srt.constrained.llguidance_backend import GuidanceBackend

        backend = GuidanceBackend(
            tokenizer=tokenizer,
            whitespace_pattern=server_args.constrained_json_whitespace_pattern,
        )
    elif backend_name == "none":
        return None
    else:
        raise ValueError(f"Invalid grammar backend: {backend_name}")

    if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
        from sglang.srt.constrained.reasoner_grammar_backend import (
            ReasonerGrammarBackend,
        )

        backend = ReasonerGrammarBackend(backend, tokenizer.think_end_id)
    return backend