[Performance] Update xgrammar-related constrained decoding (#2056)
This commit is contained in:
@@ -81,10 +81,20 @@ class OutlinesGrammar(BaseGrammarObject):
|
|||||||
):
|
):
|
||||||
self.state = next_state
|
self.state = next_state
|
||||||
|
|
||||||
def fill_vocab_mask(self, vocab_mask: torch.Tensor):
|
def allocate_vocab_mask(
|
||||||
|
self, vocab_size: int, batch_size: int, device
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
|
||||||
|
|
||||||
|
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
|
||||||
|
vocab_mask = vocab_mask[idx]
|
||||||
vocab_mask.fill_(1)
|
vocab_mask.fill_(1)
|
||||||
vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0
|
vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
|
||||||
|
logits.masked_fill_(vocab_mask, float("-inf"))
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
return OutlinesGrammar(self.guide, self.jump_forward_map)
|
return OutlinesGrammar(self.guide, self.jump_forward_map)
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,12 @@ from typing import List, Tuple
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
|
from xgrammar import (
|
||||||
|
CachedGrammarCompiler,
|
||||||
|
CompiledGrammar,
|
||||||
|
GrammarMatcher,
|
||||||
|
TokenizerInfo,
|
||||||
|
)
|
||||||
|
|
||||||
import_error = None
|
import_error = None
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
@@ -80,19 +85,23 @@ class XGrammarGrammar(BaseGrammarObject):
|
|||||||
for i in range(k, len(new_output_ids)):
|
for i in range(k, len(new_output_ids)):
|
||||||
assert self.matcher.accept_token(new_output_ids[i])
|
assert self.matcher.accept_token(new_output_ids[i])
|
||||||
|
|
||||||
def fill_vocab_mask(self, vocab_mask: torch.Tensor):
|
def allocate_vocab_mask(
|
||||||
# Note that this bitmask is a bitset, not bool
|
self, vocab_size: int, batch_size: int, device
|
||||||
bitmask = self.matcher.get_next_token_bitmask()
|
) -> torch.Tensor:
|
||||||
# Mask the tokens that are not allowed
|
return self.matcher.allocate_token_bitmask(vocab_size, batch_size)
|
||||||
vocab_mask[
|
|
||||||
self.matcher.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size)
|
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
|
||||||
] = 1
|
self.matcher.fill_next_token_bitmask(vocab_mask, idx)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
|
||||||
|
GrammarMatcher.apply_token_bitmask_inplace(logits, vocab_mask)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
matcher = GrammarMatcher(
|
matcher = GrammarMatcher(
|
||||||
self.ctx,
|
self.ctx,
|
||||||
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
|
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
|
||||||
mask_vocab_size=self.vocab_size,
|
vocab_size=self.vocab_size,
|
||||||
)
|
)
|
||||||
return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
|
return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
|
||||||
|
|
||||||
@@ -112,7 +121,8 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|||||||
self.grammar_cache = None
|
self.grammar_cache = None
|
||||||
return
|
return
|
||||||
|
|
||||||
self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
|
tokenizer_info = TokenizerInfo.from_huggingface(tokenizer)
|
||||||
|
self.grammar_cache = CachedGrammarCompiler(tokenizer_info=tokenizer_info)
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
|
|
||||||
def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
|
def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
|
||||||
@@ -122,9 +132,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|||||||
key_type, key_string = key
|
key_type, key_string = key
|
||||||
if key_type == "json":
|
if key_type == "json":
|
||||||
try:
|
try:
|
||||||
ctx = self.grammar_cache.get_compiled_grammar_for_json_schema(
|
ctx = self.grammar_cache.compile_json_schema_grammar(schema=key_string)
|
||||||
key_string
|
|
||||||
)
|
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"Skip invalid json_schema: json_schema={key_string}, {e=}"
|
f"Skip invalid json_schema: json_schema={key_string}, {e=}"
|
||||||
@@ -141,7 +149,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|||||||
matcher = GrammarMatcher(
|
matcher = GrammarMatcher(
|
||||||
ctx,
|
ctx,
|
||||||
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
|
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
|
||||||
mask_vocab_size=self.vocab_size,
|
vocab_size=self.vocab_size,
|
||||||
)
|
)
|
||||||
return XGrammarGrammar(matcher, self.vocab_size, ctx)
|
return XGrammarGrammar(matcher, self.vocab_size, ctx)
|
||||||
|
|
||||||
|
|||||||
@@ -645,7 +645,7 @@ class ModelRunner:
|
|||||||
|
|
||||||
# Apply regex vocab_mask
|
# Apply regex vocab_mask
|
||||||
if sampling_info.vocab_mask is not None:
|
if sampling_info.vocab_mask is not None:
|
||||||
logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
|
sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask)
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, Callable, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ class SamplingBatchInfo:
|
|||||||
vocab_size: int
|
vocab_size: int
|
||||||
logit_bias: torch.Tensor = None
|
logit_bias: torch.Tensor = None
|
||||||
vocab_mask: Optional[torch.Tensor] = None
|
vocab_mask: Optional[torch.Tensor] = None
|
||||||
|
apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
|
||||||
grammars: Optional[List] = None
|
grammars: Optional[List] = None
|
||||||
|
|
||||||
# Penalizer
|
# Penalizer
|
||||||
@@ -135,17 +135,23 @@ class SamplingBatchInfo:
|
|||||||
def update_regex_vocab_mask(self):
|
def update_regex_vocab_mask(self):
|
||||||
if not self.grammars or not any(grammar for grammar in self.grammars):
|
if not self.grammars or not any(grammar for grammar in self.grammars):
|
||||||
self.vocab_mask = None
|
self.vocab_mask = None
|
||||||
|
self.apply_mask = None
|
||||||
return
|
return
|
||||||
|
|
||||||
self.vocab_mask = torch.zeros(
|
# find a grammar from the list
|
||||||
len(self.temperatures),
|
grammar = next(grammar for grammar in self.grammars if grammar is not None)
|
||||||
self.vocab_size,
|
|
||||||
dtype=torch.bool,
|
# maybe we can reuse the existing mask?
|
||||||
|
self.vocab_mask = grammar.allocate_vocab_mask(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
batch_size=len(self.temperatures),
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
self.apply_mask = type(grammar).apply_vocab_mask # force to use static method
|
||||||
|
|
||||||
for i, grammar in enumerate(self.grammars):
|
for i, grammar in enumerate(self.grammars):
|
||||||
if grammar is not None:
|
if grammar is not None:
|
||||||
grammar.fill_vocab_mask(self.vocab_mask[i])
|
grammar.fill_vocab_mask(self.vocab_mask, i)
|
||||||
|
|
||||||
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
||||||
if self.penalizer_orchestrator:
|
if self.penalizer_orchestrator:
|
||||||
|
|||||||
Reference in New Issue
Block a user