update llguidance to 0.7.11; adds StructTag (#4870)

This commit is contained in:
Michał Moskal
2025-04-26 20:13:57 -07:00
committed by GitHub
parent 9ad28f639e
commit bdbe5f816b
3 changed files with 86 additions and 63 deletions

View File

@@ -24,7 +24,7 @@ runtime_common = [
"hf_transfer", "hf_transfer",
"huggingface_hub", "huggingface_hub",
"interegular", "interegular",
"llguidance>=0.6.15", "llguidance>=0.7.11,<0.8.0",
"modelscope", "modelscope",
"ninja", "ninja",
"orjson", "orjson",

View File

@@ -14,49 +14,48 @@
"""Constrained decoding with llguidance backend.""" """Constrained decoding with llguidance backend."""
import json import json
import logging
import os import os
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import llguidance
import llguidance.hf
import llguidance.torch
import torch import torch
from llguidance.gbnf_to_lark import any_to_lark from llguidance import LLMatcher, LLTokenizer, StructTag, grammar_from
from llguidance.hf import from_tokenizer
from llguidance.torch import (
allocate_token_bitmask,
apply_token_bitmask_inplace,
fill_next_token_bitmask,
)
from sglang.srt.constrained.base_grammar_backend import ( from sglang.srt.constrained.base_grammar_backend import (
BaseGrammarBackend, BaseGrammarBackend,
BaseGrammarObject, BaseGrammarObject,
) )
logger = logging.getLogger(__name__)
class GuidanceGrammar(BaseGrammarObject): class GuidanceGrammar(BaseGrammarObject):
def __init__(
self, llguidance_tokenizer: llguidance.LLTokenizer, serialized_grammar: str def __init__(self, llguidance_tokenizer: LLTokenizer, serialized_grammar: str):
):
super().__init__() super().__init__()
self.llguidance_tokenizer = llguidance_tokenizer self.llguidance_tokenizer = llguidance_tokenizer
self.serialized_grammar = serialized_grammar self.serialized_grammar = serialized_grammar
# TODO: add support for fast-forward tokens in the future self.ll_matcher = LLMatcher(
self.ll_interpreter = llguidance.LLInterpreter(
self.llguidance_tokenizer, self.llguidance_tokenizer,
self.serialized_grammar, self.serialized_grammar,
enable_backtrack=False,
enable_ff_tokens=False,
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")), log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
) )
self.pending_ff_tokens: list[int] = []
self.finished = False self.finished = False
self.bitmask = None self.bitmask = None
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]: def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
if len(self.pending_ff_tokens) > 0: ff_tokens = self.ll_matcher.compute_ff_tokens()
s = self.llguidance_tokenizer.decode_str(self.pending_ff_tokens) if ff_tokens:
ff_tokens = self.pending_ff_tokens return ff_tokens, ""
self.pending_ff_tokens = [] else:
return (ff_tokens, s) return None
return None
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]: def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
return "", -1 return "", -1
@@ -67,32 +66,22 @@ class GuidanceGrammar(BaseGrammarObject):
pass pass
def accept_token(self, token: int): def accept_token(self, token: int):
backtrack, ff_tokens = self.ll_interpreter.commit_token(token) if not self.ll_matcher.consume_token(token):
if len(ff_tokens) > 0 and backtrack == 0: logger.warning(f"matcher error: {self.ll_matcher.get_error()}")
# first token is last generated token
ff_tokens = ff_tokens[1:]
self.pending_ff_tokens.extend(ff_tokens)
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
if len(self.pending_ff_tokens) > 0:
# if we have pending fast-forward tokens,
# just return them immediately
ff_token = self.pending_ff_tokens.pop(0)
vocab_mask[idx, :] = 0
vocab_mask[idx, ff_token // 32] = 1 << (ff_token % 32)
return
if self.ll_interpreter.has_pending_stop():
self.finished = True self.finished = True
llguidance.torch.fill_next_token_bitmask(self.ll_interpreter, vocab_mask, idx) def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
if self.ll_matcher.is_stopped():
self.finished = True
fill_next_token_bitmask(self.ll_matcher, vocab_mask, idx)
def allocate_vocab_mask( def allocate_vocab_mask(
self, vocab_size: int, batch_size: int, device self, vocab_size: int, batch_size: int, device
) -> torch.Tensor: ) -> torch.Tensor:
if self.bitmask is None or self.bitmask.shape[0] < batch_size: if self.bitmask is None or self.bitmask.shape[0] < batch_size:
# only create bitmask when batch gets larger # only create bitmask when batch gets larger
self.bitmask = llguidance.torch.allocate_token_bitmask( self.bitmask = allocate_token_bitmask(
batch_size, self.llguidance_tokenizer.vocab_size batch_size, self.llguidance_tokenizer.vocab_size
) )
bitmask = self.bitmask bitmask = self.bitmask
@@ -107,7 +96,7 @@ class GuidanceGrammar(BaseGrammarObject):
@staticmethod @staticmethod
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
llguidance.torch.apply_token_bitmask_inplace(logits, vocab_mask) apply_token_bitmask_inplace(logits, vocab_mask)
def copy(self): def copy(self):
return GuidanceGrammar( return GuidanceGrammar(
@@ -117,36 +106,64 @@ class GuidanceGrammar(BaseGrammarObject):
class GuidanceBackend(BaseGrammarBackend): class GuidanceBackend(BaseGrammarBackend):
def __init__(self, tokenizer, whitespace_pattern: Optional[str] = None):
def __init__(
self,
tokenizer,
whitespace_pattern: Optional[str] = None,
n_vocab: Optional[int] = None,
):
super().__init__() super().__init__()
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.whitespace_flexible = ( self.whitespace_pattern = whitespace_pattern
True if whitespace_pattern == "whitespace_flexible" else False self.llguidance_tokenizer = from_tokenizer(self.tokenizer, n_vocab)
def _from_serialized(self, serialized_grammar) -> Optional[GuidanceGrammar]:
try:
return GuidanceGrammar(
llguidance_tokenizer=self.llguidance_tokenizer,
serialized_grammar=serialized_grammar,
)
except Exception as e:
logger.warning(f"Skip invalid grammar: {serialized_grammar}, {e=}")
return None
def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
serialized_grammar = LLMatcher.grammar_from_json_schema(
key_string,
defaults={
"whitespace_pattern": self.whitespace_pattern,
},
) )
self.llguidance_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)
def _from_serialized(self, serialized_grammar) -> GuidanceGrammar:
return GuidanceGrammar(
llguidance_tokenizer=self.llguidance_tokenizer,
serialized_grammar=serialized_grammar,
)
def dispatch_json(self, key_string: str) -> GuidanceGrammar:
json_schema = key_string
compiler = llguidance.JsonCompiler(whitespace_flexible=self.whitespace_flexible)
serialized_grammar = compiler.compile(json_schema)
return self._from_serialized(serialized_grammar) return self._from_serialized(serialized_grammar)
def dispatch_regex(self, key_string: str) -> GuidanceGrammar: def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
compiler = llguidance.RegexCompiler() serialized_grammar = grammar_from("regex", key_string)
serialized_grammar = compiler.compile(regex=key_string)
return self._from_serialized(serialized_grammar) return self._from_serialized(serialized_grammar)
def dispatch_ebnf(self, key_string: str) -> GuidanceGrammar: def dispatch_ebnf(self, key_string: str) -> Optional[GuidanceGrammar]:
compiler = llguidance.LarkCompiler() try:
serialized_grammar = compiler.compile(any_to_lark(key_string)) serialized_grammar = grammar_from("ebnf", key_string)
return self._from_serialized(serialized_grammar) return self._from_serialized(serialized_grammar)
except ValueError as e:
logger.warning(f"Skip invalid ebnf: regex={key_string}, {e=}")
return None
def dispatch_structural_tag(self, key_string: str): def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]:
return super().dispatch_structural_tag(key_string) try:
structural_tag = json.loads(key_string)
tags = [
StructTag(
begin=structure["begin"],
grammar=structure["schema"],
end=structure["end"],
trigger=structural_tag["triggers"][0], # TODO?
)
for structure in structural_tag["structures"]
]
g = StructTag.to_grammar(tags)
return self._from_serialized(g)
except Exception as e:
logging.warning(f"Skip invalid structural_tag: {key_string}, {e=}")
return None

View File

@@ -238,5 +238,11 @@ class TestEBNFConstrained(CustomTestCase):
) )
class TestEBNFConstrainedLLGuidance(TestEBNFConstrained):
@classmethod
def setUpClass(cls):
setup_class(cls, "llguidance", disable_overlap=False)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()