Fix grammar backend for tensor parallelism (#2020)

This commit is contained in:
Lianmin Zheng
2024-11-13 01:49:45 -08:00
committed by GitHub
parent ba069a24d3
commit 54479d6f30
7 changed files with 250 additions and 328 deletions

View File

@@ -15,38 +15,36 @@ limitations under the License.
"""Constrained decoding with xgrammar backend."""
from concurrent.futures import Future, ThreadPoolExecutor
from typing import List, Tuple
import torch

# xgrammar is an optional dependency: if it is missing we record the
# ImportError (raised later, at backend construction) and bind placeholder
# names so module-level annotations still resolve.
try:
    from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher

    import_error = None
except ImportError as e:
    import_error = e

    class Dummy:
        pass

    GrammarMatcher = CompiledGrammar = CachedGrammarCompiler = Dummy
from sglang.srt.constrained.base_grammar_backend import (
BaseGrammarBackend,
BaseGrammarObject,
)
MAX_ROLLBACK_TOKENS = 10
class XGrammarGrammar(BaseGrammarObject):
    """Per-request grammar state: an xgrammar matcher plus the compiled
    grammar it was built from (kept so ``copy()`` can clone the matcher)."""

    def __init__(
        self, matcher: GrammarMatcher, vocab_size: int, ctx: CompiledGrammar
    ) -> None:
        self.matcher = matcher  # stateful matcher advanced token by token
        self.vocab_size = vocab_size  # vocab size used for token bitmasks
        self.ctx = ctx  # compiled grammar, shared across copies
def accept_token(self, token: int):
    """Advance the matcher by one generated token.

    The token must be grammar-legal; a rejection indicates a bug upstream
    (the vocab mask should have prevented it), so we fail loudly.
    """
    accepted = self.matcher.accept_token(token)
    assert accepted
def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]:
    """Return ``([], jump_str)`` when the grammar forces a non-empty
    jump-forward string, or ``None`` when no jump is available.

    Returning ``None`` for an empty string (rather than ``([], "")``)
    lets callers distinguish "nothing to jump" from a real jump.
    """
    s = self.matcher.find_jump_forward_string()
    if s:
        return [], s
    return None
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
_, data = helper
@@ -77,51 +75,40 @@ class XGrammarGrammar:
self.matcher.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size)
] = 1
def copy(self):
    """Clone this grammar state: a fresh matcher over the same compiled
    grammar, so the copy can advance independently of the original."""
    cloned_matcher = GrammarMatcher(
        self.ctx,
        max_rollback_tokens=MAX_ROLLBACK_TOKENS,
        mask_vocab_size=self.vocab_size,
    )
    return XGrammarGrammar(cloned_matcher, self.vocab_size, self.ctx)
class XGrammarGrammarBackend(BaseGrammarBackend):
    """Grammar backend that compiles grammars via xgrammar and caches the
    compiled results per (key_type, key_string)."""

    def __init__(
        self,
        tokenizer,
        vocab_size: int,
    ):
        super().__init__()
        # Surface the deferred optional-import failure at construction time,
        # so only users of this backend pay for a missing xgrammar install.
        if import_error:
            raise import_error
        self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
        self.vocab_size = vocab_size

    def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
        """Compile (or fetch from the cache) the grammar for ``key`` and wrap
        it in a fresh XGrammarGrammar.

        Raises ValueError for unsupported key types ("regex" is not yet
        supported by xgrammar) and for unknown key types.
        """
        key_type, key_string = key
        if key_type == "json":
            ctx = self.grammar_cache.get_compiled_grammar_for_json_schema(key_string)
        elif key_type == "regex":
            raise ValueError("regex hasn't been supported by xgrammar yet")
        else:
            raise ValueError(f"Invalid key_type: {key_type}")
        matcher = GrammarMatcher(
            ctx,
            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
            mask_vocab_size=self.vocab_size,
        )
        return XGrammarGrammar(matcher, self.vocab_size, ctx)

    def reset(self):
        """Drop all cached compiled grammars."""
        self.grammar_cache.clear()