[Feature] Support llguidance for constrained decoding (#3298)
This commit is contained in:
@@ -38,6 +38,7 @@ runtime_common = [
|
||||
"xgrammar==0.1.10",
|
||||
"ninja",
|
||||
"transformers==4.48.3",
|
||||
"llguidance>=0.6.15"
|
||||
]
|
||||
srt = [
|
||||
"sglang[runtime_common]",
|
||||
|
||||
@@ -86,6 +86,13 @@ def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
|
||||
from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend
|
||||
|
||||
grammar_backend = XGrammarGrammarBackend(tokenizer, vocab_size=vocab_size)
|
||||
elif server_args.grammar_backend == "llguidance":
|
||||
from sglang.srt.constrained.llguidance_backend import GuidanceBackend
|
||||
|
||||
grammar_backend = GuidanceBackend(
|
||||
tokenizer=tokenizer,
|
||||
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
|
||||
|
||||
|
||||
146
python/sglang/srt/constrained/llguidance_backend.py
Normal file
146
python/sglang/srt/constrained/llguidance_backend.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Constrained decoding with llguidance backend."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import llguidance
|
||||
import llguidance.hf
|
||||
import llguidance.torch
|
||||
import torch
|
||||
from llguidance.gbnf_to_lark import any_to_lark
|
||||
|
||||
from sglang.srt.constrained.base_grammar_backend import (
|
||||
BaseGrammarBackend,
|
||||
BaseGrammarObject,
|
||||
)
|
||||
|
||||
|
||||
class GuidanceGrammar(BaseGrammarObject):
|
||||
def __init__(
|
||||
self, llguidance_tokenizer: llguidance.LLTokenizer, serialized_grammar: str
|
||||
):
|
||||
self.llguidance_tokenizer = llguidance_tokenizer
|
||||
self.serialized_grammar = serialized_grammar
|
||||
|
||||
# TODO: add support for fast-forward tokens in the future
|
||||
self.ll_interpreter = llguidance.LLInterpreter(
|
||||
self.llguidance_tokenizer,
|
||||
self.serialized_grammar,
|
||||
enable_backtrack=False,
|
||||
enable_ff_tokens=False,
|
||||
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
|
||||
)
|
||||
self.pending_ff_tokens: list[int] = []
|
||||
self.finished = False
|
||||
self.bitmask = None
|
||||
|
||||
def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]:
|
||||
if len(self.pending_ff_tokens) > 0:
|
||||
s = self.llguidance_tokenizer.decode_str(self.pending_ff_tokens)
|
||||
ff_tokens = self.pending_ff_tokens
|
||||
self.pending_ff_tokens = []
|
||||
return (ff_tokens, s)
|
||||
|
||||
return None
|
||||
|
||||
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
|
||||
return "", -1
|
||||
|
||||
def jump_and_retokenize(
|
||||
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
|
||||
):
|
||||
pass
|
||||
|
||||
def accept_token(self, token: int):
|
||||
backtrack, ff_tokens = self.ll_interpreter.commit_token(token)
|
||||
if len(ff_tokens) > 0 and backtrack == 0:
|
||||
# first token is last generated token
|
||||
ff_tokens = ff_tokens[1:]
|
||||
self.pending_ff_tokens.extend(ff_tokens)
|
||||
|
||||
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
|
||||
if len(self.pending_ff_tokens) > 0:
|
||||
# if we have pending fast-forward tokens,
|
||||
# just return them immediately
|
||||
ff_token = self.pending_ff_tokens.pop(0)
|
||||
vocab_mask[idx, :] = 0
|
||||
vocab_mask[idx, ff_token // 32] = 1 << (ff_token % 32)
|
||||
return
|
||||
|
||||
if self.ll_interpreter.has_pending_stop():
|
||||
self.finished = True
|
||||
|
||||
llguidance.torch.fill_next_token_bitmask(self.ll_interpreter, vocab_mask, idx)
|
||||
|
||||
def allocate_vocab_mask(
|
||||
self, vocab_size: int, batch_size: int, device
|
||||
) -> torch.Tensor:
|
||||
if self.bitmask is None or self.bitmask.shape[0] < batch_size:
|
||||
# only create bitmask when batch gets larger
|
||||
self.bitmask = llguidance.torch.allocate_token_bitmask(
|
||||
batch_size, self.llguidance_tokenizer.vocab_size
|
||||
)
|
||||
bitmask = self.bitmask
|
||||
else:
|
||||
bitmask = self.bitmask[:batch_size]
|
||||
|
||||
return bitmask
|
||||
|
||||
@staticmethod
|
||||
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
|
||||
return vocab_mask.to(device, non_blocking=True)
|
||||
|
||||
@staticmethod
|
||||
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
|
||||
llguidance.torch.apply_token_bitmask_inplace(logits, vocab_mask)
|
||||
|
||||
def copy(self):
|
||||
return GuidanceGrammar(
|
||||
llguidance_tokenizer=self.llguidance_tokenizer,
|
||||
serialized_grammar=self.serialized_grammar,
|
||||
)
|
||||
|
||||
|
||||
class GuidanceBackend(BaseGrammarBackend):
|
||||
def __init__(self, tokenizer, whitespace_pattern: Optional[str] = None):
|
||||
super().__init__()
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.whitespace_flexible = (
|
||||
True if whitespace_pattern == "whitespace_flexible" else False
|
||||
)
|
||||
self.llguidance_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)
|
||||
|
||||
def init_value_impl(self, key: Tuple[str, str]) -> GuidanceGrammar:
|
||||
mode, value = key
|
||||
if mode == "json":
|
||||
json_schema = value
|
||||
compiler = llguidance.JsonCompiler(
|
||||
whitespace_flexible=self.whitespace_flexible
|
||||
)
|
||||
serialized_grammar = compiler.compile(json_schema)
|
||||
elif mode == "regex":
|
||||
compiler = llguidance.RegexCompiler()
|
||||
serialized_grammar = compiler.compile(regex=value)
|
||||
elif mode == "ebnf":
|
||||
compiler = llguidance.LarkCompiler()
|
||||
serialized_grammar = compiler.compile(any_to_lark(value))
|
||||
|
||||
return GuidanceGrammar(
|
||||
llguidance_tokenizer=self.llguidance_tokenizer,
|
||||
serialized_grammar=serialized_grammar,
|
||||
)
|
||||
@@ -698,7 +698,7 @@ class ServerArgs:
|
||||
parser.add_argument(
|
||||
"--grammar-backend",
|
||||
type=str,
|
||||
choices=["xgrammar", "outlines"],
|
||||
choices=["xgrammar", "outlines", "llguidance"],
|
||||
default=ServerArgs.grammar_backend,
|
||||
help="Choose the backend for grammar-guided decoding.",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user