diff --git a/python/pyproject.toml b/python/pyproject.toml index e487dba1a..350f085e4 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -24,7 +24,7 @@ runtime_common = [ "hf_transfer", "huggingface_hub", "interegular", - "llguidance>=0.6.15", + "llguidance>=0.7.11,<0.8.0", "modelscope", "ninja", "orjson", diff --git a/python/sglang/srt/constrained/llguidance_backend.py b/python/sglang/srt/constrained/llguidance_backend.py index 6926e1c30..49c3740fb 100644 --- a/python/sglang/srt/constrained/llguidance_backend.py +++ b/python/sglang/srt/constrained/llguidance_backend.py @@ -14,49 +14,48 @@ """Constrained decoding with llguidance backend.""" import json +import logging import os from typing import List, Optional, Tuple -import llguidance -import llguidance.hf -import llguidance.torch import torch -from llguidance.gbnf_to_lark import any_to_lark +from llguidance import LLMatcher, LLTokenizer, StructTag, grammar_from +from llguidance.hf import from_tokenizer +from llguidance.torch import ( + allocate_token_bitmask, + apply_token_bitmask_inplace, + fill_next_token_bitmask, +) from sglang.srt.constrained.base_grammar_backend import ( BaseGrammarBackend, BaseGrammarObject, ) +logger = logging.getLogger(__name__) + class GuidanceGrammar(BaseGrammarObject): - def __init__( - self, llguidance_tokenizer: llguidance.LLTokenizer, serialized_grammar: str - ): + + def __init__(self, llguidance_tokenizer: LLTokenizer, serialized_grammar: str): super().__init__() self.llguidance_tokenizer = llguidance_tokenizer self.serialized_grammar = serialized_grammar - # TODO: add support for fast-forward tokens in the future - self.ll_interpreter = llguidance.LLInterpreter( + self.ll_matcher = LLMatcher( self.llguidance_tokenizer, self.serialized_grammar, - enable_backtrack=False, - enable_ff_tokens=False, log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")), ) - self.pending_ff_tokens: list[int] = [] self.finished = False self.bitmask = None def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]: - if len(self.pending_ff_tokens) > 0: - s = self.llguidance_tokenizer.decode_str(self.pending_ff_tokens) - ff_tokens = self.pending_ff_tokens - self.pending_ff_tokens = [] - return (ff_tokens, s) - - return None + ff_tokens = self.ll_matcher.compute_ff_tokens() + if ff_tokens: + return ff_tokens, "" + else: + return None def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]: return "", -1 @@ -67,32 +66,22 @@ class GuidanceGrammar(BaseGrammarObject): pass def accept_token(self, token: int): - backtrack, ff_tokens = self.ll_interpreter.commit_token(token) - if len(ff_tokens) > 0 and backtrack == 0: - # first token is last generated token - ff_tokens = ff_tokens[1:] - self.pending_ff_tokens.extend(ff_tokens) - - def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None: - if len(self.pending_ff_tokens) > 0: - # if we have pending fast-forward tokens, - # just return them immediately - ff_token = self.pending_ff_tokens.pop(0) - vocab_mask[idx, :] = 0 - vocab_mask[idx, ff_token // 32] = 1 << (ff_token % 32) - return - - if self.ll_interpreter.has_pending_stop(): + if not self.ll_matcher.consume_token(token): + logger.warning(f"matcher error: {self.ll_matcher.get_error()}") self.finished = True - llguidance.torch.fill_next_token_bitmask(self.ll_interpreter, vocab_mask, idx) + def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None: + if self.ll_matcher.is_stopped(): + self.finished = True + + fill_next_token_bitmask(self.ll_matcher, vocab_mask, idx) def allocate_vocab_mask( self, vocab_size: int, batch_size: int, device ) -> torch.Tensor: if self.bitmask is None or self.bitmask.shape[0] < batch_size: # only create bitmask when batch gets larger - self.bitmask = llguidance.torch.allocate_token_bitmask( + self.bitmask = allocate_token_bitmask( batch_size, self.llguidance_tokenizer.vocab_size ) bitmask = self.bitmask @@ -107,7 +96,7 @@ class GuidanceGrammar(BaseGrammarObject): @staticmethod def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: - llguidance.torch.apply_token_bitmask_inplace(logits, vocab_mask) + apply_token_bitmask_inplace(logits, vocab_mask) def copy(self): return GuidanceGrammar( @@ -117,36 +106,64 @@ class GuidanceGrammar(BaseGrammarObject): class GuidanceBackend(BaseGrammarBackend): - def __init__(self, tokenizer, whitespace_pattern: Optional[str] = None): + + def __init__( + self, + tokenizer, + whitespace_pattern: Optional[str] = None, + n_vocab: Optional[int] = None, + ): super().__init__() self.tokenizer = tokenizer - self.whitespace_flexible = ( - True if whitespace_pattern == "whitespace_flexible" else False + self.whitespace_pattern = whitespace_pattern + self.llguidance_tokenizer = from_tokenizer(self.tokenizer, n_vocab) + + def _from_serialized(self, serialized_grammar) -> Optional[GuidanceGrammar]: + try: + return GuidanceGrammar( + llguidance_tokenizer=self.llguidance_tokenizer, + serialized_grammar=serialized_grammar, + ) + except Exception as e: + logger.warning(f"Skip invalid grammar: {serialized_grammar}, {e=}") + return None + + def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]: + serialized_grammar = LLMatcher.grammar_from_json_schema( + key_string, + defaults={ + "whitespace_pattern": self.whitespace_pattern, + }, ) - self.llguidance_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None) - - def _from_serialized(self, serialized_grammar) -> GuidanceGrammar: - return GuidanceGrammar( - llguidance_tokenizer=self.llguidance_tokenizer, - serialized_grammar=serialized_grammar, - ) - - def dispatch_json(self, key_string: str) -> GuidanceGrammar: - json_schema = key_string - compiler = llguidance.JsonCompiler(whitespace_flexible=self.whitespace_flexible) - serialized_grammar = compiler.compile(json_schema) return self._from_serialized(serialized_grammar) - def dispatch_regex(self, key_string: str) -> GuidanceGrammar: - compiler = llguidance.RegexCompiler() - serialized_grammar = compiler.compile(regex=key_string) + def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]: + serialized_grammar = grammar_from("regex", key_string) return self._from_serialized(serialized_grammar) - def dispatch_ebnf(self, key_string: str) -> GuidanceGrammar: - compiler = llguidance.LarkCompiler() - serialized_grammar = compiler.compile(any_to_lark(key_string)) - return self._from_serialized(serialized_grammar) + def dispatch_ebnf(self, key_string: str) -> Optional[GuidanceGrammar]: + try: + serialized_grammar = grammar_from("ebnf", key_string) + return self._from_serialized(serialized_grammar) + except ValueError as e: + logger.warning(f"Skip invalid ebnf: regex={key_string}, {e=}") + return None - def dispatch_structural_tag(self, key_string: str): - return super().dispatch_structural_tag(key_string) + def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]: + try: + structural_tag = json.loads(key_string) + tags = [ + StructTag( + begin=structure["begin"], + grammar=structure["schema"], + end=structure["end"], + trigger=structural_tag["triggers"][0], # TODO? + ) + for structure in structural_tag["structures"] + ] + g = StructTag.to_grammar(tags) + return self._from_serialized(g) + except Exception as e: + logging.warning(f"Skip invalid structural_tag: {key_string}, {e=}") + return None diff --git a/test/srt/test_ebnf_constrained.py b/test/srt/test_ebnf_constrained.py index e0771b97d..d8792f98b 100644 --- a/test/srt/test_ebnf_constrained.py +++ b/test/srt/test_ebnf_constrained.py @@ -238,5 +238,11 @@ class TestEBNFConstrained(CustomTestCase): ) +class TestEBNFConstrainedLLGuidance(TestEBNFConstrained): + @classmethod + def setUpClass(cls): + setup_class(cls, "llguidance", disable_overlap=False) + + if __name__ == "__main__": unittest.main()