diff --git a/python/sglang/srt/constrained/base_grammar_backend.py b/python/sglang/srt/constrained/base_grammar_backend.py index 5412caa22..5d1b0da29 100644 --- a/python/sglang/srt/constrained/base_grammar_backend.py +++ b/python/sglang/srt/constrained/base_grammar_backend.py @@ -13,31 +13,130 @@ # ============================================================================== """The baseclass of a backend for grammar-guided constrained decoding.""" +import logging +from abc import ABC, abstractmethod from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass from threading import Event, Lock -from typing import Any, Optional, Tuple +from typing import Dict, List, Optional, Tuple + +import torch from sglang.srt.server_args import ServerArgs +logger = logging.getLogger(__name__) + + +class BaseGrammarObject(ABC): + @abstractmethod + def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]: + """ + Try to jump forward in the grammar. + + Returns: + A jump forward helper which may be used in `jump_forward_str_state`. + None if the jump forward is not possible. + """ + raise NotImplementedError + + @abstractmethod + def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]: + """ + Jump forward for the grammar. + + Returns: + A tuple of the jump forward string and the next state of the grammar + (which can be used in `jump_and_retokenize` if needed). + """ + raise NotImplementedError + + @abstractmethod + def jump_and_retokenize( + self, old_output_ids: List[int], new_output_ids: List[int], next_state: int + ) -> None: + """ + Jump forward occurs, and update the grammar state if needed. 
+ """ + raise NotImplementedError + + @abstractmethod + def allocate_vocab_mask( + self, vocab_size: int, batch_size: int, device + ) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None: + raise NotImplementedError + + @staticmethod + @abstractmethod + def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor: + raise NotImplementedError + + @staticmethod + @abstractmethod + def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: + raise NotImplementedError + + @abstractmethod + def copy(self) -> "BaseGrammarObject": + raise NotImplementedError + @dataclass class CacheEntry: - value: Any + value: Optional[BaseGrammarObject] event: Event -class BaseGrammarObject: - pass - - -class BaseGrammarBackend: +class BaseGrammarBackend(ABC): def __init__(self): self.executor = ThreadPoolExecutor() - self.cache = {} + self.cache: Dict[Tuple[str, str], CacheEntry] = {} self.cache_lock = Lock() - def init_value(self, key: Tuple[str, str]) -> BaseGrammarObject: + def _not_supported(self, key_type: str, key_string: str) -> None: + logger.warning(f"Skip unsupported {key_type}: {key_type}={key_string}") + + def dispatch_fallback( + self, key_type: str, key_string: str + ) -> Optional[BaseGrammarObject]: + """ + This function should not be reached in any case. 
+ """ + raise ValueError(f"Invalid key_type: {key_type}={key_string}") + + @abstractmethod + def dispatch_json(self, key_string: str) -> Optional[BaseGrammarObject]: + return self._not_supported("json", key_string) + + @abstractmethod + def dispatch_regex(self, key_string: str) -> Optional[BaseGrammarObject]: + return self._not_supported("regex", key_string) + + @abstractmethod + def dispatch_ebnf(self, key_string: str) -> Optional[BaseGrammarObject]: + return self._not_supported("ebnf", key_string) + + @abstractmethod + def dispatch_structural_tag(self, key_string: str) -> Optional[BaseGrammarObject]: + return self._not_supported("structural_tag", key_string) + + def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]: + key_type, key_string = key + if key_type == "json": + return self.dispatch_json(key_string) + elif key_type == "regex": + return self.dispatch_regex(key_string) + elif key_type == "ebnf": + return self.dispatch_ebnf(key_string) + elif key_type == "structural_tag": + return self.dispatch_structural_tag(key_string) + else: + return self.dispatch_fallback(key_type, key_string) + + def _init_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]: with self.cache_lock: if key in self.cache: cache_hit = True @@ -50,13 +149,10 @@ class BaseGrammarBackend: if cache_hit: entry.event.wait() else: - entry.value = self.init_value_impl(key) + entry.value = self._init_value_dispatch(key) entry.event.set() return entry.value.copy() if entry.value else None - def init_value_impl(self, key: Tuple[str, str]) -> BaseGrammarObject: - raise NotImplementedError() - def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]: with self.cache_lock: entry = self.cache.get(key) @@ -66,7 +162,7 @@ class BaseGrammarBackend: return val.copy() if val else None def get_future_value(self, key: Tuple[str, str]) -> Future: - return self.executor.submit(self.init_value, key) + return self.executor.submit(self._init_value, 
key) def reset(self): with self.cache_lock: diff --git a/python/sglang/srt/constrained/llguidance_backend.py b/python/sglang/srt/constrained/llguidance_backend.py index 5d2b69790..24893a49d 100644 --- a/python/sglang/srt/constrained/llguidance_backend.py +++ b/python/sglang/srt/constrained/llguidance_backend.py @@ -48,7 +48,7 @@ class GuidanceGrammar(BaseGrammarObject): self.finished = False self.bitmask = None - def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]: + def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]: if len(self.pending_ff_tokens) > 0: s = self.llguidance_tokenizer.decode_str(self.pending_ff_tokens) ff_tokens = self.pending_ff_tokens @@ -125,22 +125,27 @@ class GuidanceBackend(BaseGrammarBackend): ) self.llguidance_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None) - def init_value_impl(self, key: Tuple[str, str]) -> GuidanceGrammar: - mode, value = key - if mode == "json": - json_schema = value - compiler = llguidance.JsonCompiler( - whitespace_flexible=self.whitespace_flexible - ) - serialized_grammar = compiler.compile(json_schema) - elif mode == "regex": - compiler = llguidance.RegexCompiler() - serialized_grammar = compiler.compile(regex=value) - elif mode == "ebnf": - compiler = llguidance.LarkCompiler() - serialized_grammar = compiler.compile(any_to_lark(value)) - + def _from_serialized(self, serialized_grammar) -> GuidanceGrammar: return GuidanceGrammar( llguidance_tokenizer=self.llguidance_tokenizer, serialized_grammar=serialized_grammar, ) + + def dispatch_json(self, key_string: str) -> GuidanceGrammar: + json_schema = key_string + compiler = llguidance.JsonCompiler(whitespace_flexible=self.whitespace_flexible) + serialized_grammar = compiler.compile(json_schema) + return self._from_serialized(serialized_grammar) + + def dispatch_regex(self, key_string: str) -> GuidanceGrammar: + compiler = llguidance.RegexCompiler() + serialized_grammar = compiler.compile(regex=key_string) + return 
self._from_serialized(serialized_grammar) + + def dispatch_ebnf(self, key_string: str) -> GuidanceGrammar: + compiler = llguidance.LarkCompiler() + serialized_grammar = compiler.compile(any_to_lark(key_string)) + return self._from_serialized(serialized_grammar) + + def dispatch_structural_tag(self, key_string: str): + return super().dispatch_structural_tag(key_string) diff --git a/python/sglang/srt/constrained/outlines_backend.py b/python/sglang/srt/constrained/outlines_backend.py index 1cf46bd01..7a05bb2c5 100644 --- a/python/sglang/srt/constrained/outlines_backend.py +++ b/python/sglang/srt/constrained/outlines_backend.py @@ -141,24 +141,7 @@ class OutlinesGrammarBackend(BaseGrammarBackend): ) self.whitespace_pattern = whitespace_pattern - def init_value_impl(self, key: Tuple[str, str]) -> OutlinesGrammar: - key_type, key_string = key - if key_type == "json": - try: - regex = build_regex_from_object( - key_string, - whitespace_pattern=self.whitespace_pattern, - ) - except (NotImplementedError, json.decoder.JSONDecodeError) as e: - logger.warning( - f"Skip invalid json_schema: json_schema={key_string}, {e=}" - ) - return None - elif key_type == "regex": - regex = key_string - else: - raise ValueError(f"Invalid key_type: {key_type}") - + def _compile_regex(self, regex: str) -> Optional[OutlinesGrammar]: try: if hasattr(RegexGuide, "from_regex"): # outlines >= 0.1.1 @@ -173,6 +156,25 @@ class OutlinesGrammarBackend(BaseGrammarBackend): jump_forward_map = None return OutlinesGrammar(guide, jump_forward_map) + def dispatch_ebnf(self, key_string: str): + return super().dispatch_ebnf(key_string) + + def dispatch_structural_tag(self, key_string: str): + return super().dispatch_structural_tag(key_string) + + def dispatch_json(self, key_string: str): + try: + regex = build_regex_from_object( + key_string, + whitespace_pattern=self.whitespace_pattern, + ) + except (NotImplementedError, json.decoder.JSONDecodeError) as e: + logger.warning(f"Skip invalid json_schema: 
json_schema={key_string}, {e=}") +            return None + return self._compile_regex(regex) + + def dispatch_regex(self, key_string: str): + return self._compile_regex(key_string) + def build_regex_from_object( object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py index b7f9a15e9..4df3ae286 100644 --- a/python/sglang/srt/constrained/xgrammar_backend.py +++ b/python/sglang/srt/constrained/xgrammar_backend.py @@ -57,7 +57,7 @@ class XGrammarGrammar(BaseGrammarObject): def accept_token(self, token: int): assert self.matcher.accept_token(token) - def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]: + def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]: s = self.matcher.find_jump_forward_string() if s: return [], s @@ -128,55 +128,56 @@ class XGrammarGrammarBackend(BaseGrammarBackend): self.vocab_size = vocab_size self.override_stop_tokens = override_stop_tokens - def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar: - - key_type, key_string = key - if key_type == "json": - try: - if key_string == "$$ANY$$": - ctx = self.grammar_compiler.compile_builtin_json_grammar() - else: - ctx = self.grammar_compiler.compile_json_schema(schema=key_string) - except RuntimeError as e: - logging.warning( - f"Skip invalid json_schema: json_schema={key_string}, {e=}" - ) - return None - elif key_type == "ebnf": - try: - ctx = self.grammar_compiler.compile_grammar(key_string) - except RuntimeError as e: - logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}") - return None - elif key_type == "regex": - try: - ctx = self.grammar_compiler.compile_regex(key_string) - except RuntimeError as e: - logging.warning(f"Skip invalid regex: regex={key_string}, {e=}") - return None - elif key_type == "structural_tag": - try: - structural_tag = json.loads(key_string) - tags = [ - StructuralTagItem( - begin=structure["begin"], - 
schema=json.dumps(structure["schema"]), - end=structure["end"], - ) - for structure in structural_tag["structures"] - ] - ctx = self.grammar_compiler.compile_structural_tag( - tags, structural_tag["triggers"] - ) - except RuntimeError as e: - logging.warning(f"Skip invalid regex: regex={key_string}, {e=}") - return None - else: - raise ValueError(f"Invalid key_type: {key_type}") - + def _from_context(self, ctx: CompiledGrammar) -> XGrammarGrammar: matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS) return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens) + def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]: + try: + if key_string == "$$ANY$$": + ctx = self.grammar_compiler.compile_builtin_json_grammar() + else: + ctx = self.grammar_compiler.compile_json_schema(schema=key_string) + except RuntimeError as e: + logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}") + return None + return self._from_context(ctx) + + def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]: + try: + ctx = self.grammar_compiler.compile_grammar(key_string) + except RuntimeError as e: + logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}") + return None + return self._from_context(ctx) + + def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]: + try: + ctx = self.grammar_compiler.compile_regex(key_string) + except RuntimeError as e: + logging.warning(f"Skip invalid regex: regex={key_string}, {e=}") + return None + return self._from_context(ctx) + + def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]: + try: + structural_tag = json.loads(key_string) + tags = [ + StructuralTagItem( + begin=structure["begin"], + schema=json.dumps(structure["schema"]), + end=structure["end"], + ) + for structure in structural_tag["structures"] + ] + ctx = self.grammar_compiler.compile_structural_tag( + tags, structural_tag["triggers"] + ) + except RuntimeError 
as e: + logging.warning(f"Skip invalid structural_tag: structural_tag={key_string}, {e=}") + return None + return self._from_context(ctx) + def reset(self): if self.grammar_compiler: self.grammar_compiler.clear_cache()