[Performance] Update xgrammar-related constrained decoding (#2056)

Author: DarkSharpness
Date: 2024-11-18 09:58:49 +09:00
Committed via: GitHub
Parent: ebaa2f3199
Commit: 9c745d078e
4 changed files with 47 additions and 23 deletions


@@ -1,7 +1,7 @@
 from __future__ import annotations

 import dataclasses
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional

 import torch
@@ -29,7 +29,8 @@ class SamplingBatchInfo:
     vocab_size: int
     logit_bias: torch.Tensor = None
     vocab_mask: Optional[torch.Tensor] = None
+    apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
     grammars: Optional[List] = None

     # Penalizer
@@ -135,17 +135,23 @@ class SamplingBatchInfo:
     def update_regex_vocab_mask(self):
         if not self.grammars or not any(grammar for grammar in self.grammars):
             self.vocab_mask = None
+            self.apply_mask = None
             return

-        self.vocab_mask = torch.zeros(
-            len(self.temperatures),
-            self.vocab_size,
-            dtype=torch.bool,
+        # find a grammar from the list
+        grammar = next(grammar for grammar in self.grammars if grammar is not None)
+        # maybe we can reuse the existing mask?
+        self.vocab_mask = grammar.allocate_vocab_mask(
+            vocab_size=self.vocab_size,
+            batch_size=len(self.temperatures),
             device=self.device,
         )
+        self.apply_mask = type(grammar).apply_vocab_mask  # force to use static method

         for i, grammar in enumerate(self.grammars):
             if grammar is not None:
-                grammar.fill_vocab_mask(self.vocab_mask[i])
+                grammar.fill_vocab_mask(self.vocab_mask, i)

     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         if self.penalizer_orchestrator:
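
For readers unfamiliar with the new interface, here is a minimal, self-contained sketch of how the pieces fit together at sampling time. Only the method names (allocate_vocab_mask, fill_vocab_mask, apply_vocab_mask) and the Callable[[torch.Tensor, torch.Tensor], None] contract come from the diff above; the BitmaskGrammar class and its all-bool mask layout are hypothetical stand-ins for the xgrammar-backed grammar objects:

from typing import Callable, List, Optional

import torch


class BitmaskGrammar:
    """Hypothetical grammar backend following the interface in the diff."""

    def allocate_vocab_mask(self, vocab_size: int, batch_size: int, device) -> torch.Tensor:
        # Allocation is delegated to the grammar so a backend can pick its
        # own mask layout (e.g. a packed bitset instead of plain bools).
        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        # Mark the tokens that are disallowed in row `idx` (illustrative only).
        vocab_mask[idx, ::2] = True

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        # In-place masking, matching Callable[[torch.Tensor, torch.Tensor], None].
        logits.masked_fill_(vocab_mask, float("-inf"))


# Mirrors update_regex_vocab_mask: allocate once per batch, fill per row,
# then apply the stored callable to the logits before sampling.
grammars: List[Optional[BitmaskGrammar]] = [BitmaskGrammar(), None]
logits = torch.randn(2, 8)

grammar = next(g for g in grammars if g is not None)
vocab_mask = grammar.allocate_vocab_mask(vocab_size=8, batch_size=2, device=logits.device)
apply_mask: Callable[[torch.Tensor, torch.Tensor], None] = type(grammar).apply_vocab_mask

for i, g in enumerate(grammars):
    if g is not None:
        g.fill_vocab_mask(vocab_mask, i)

apply_mask(logits, vocab_mask)  # rows without a grammar keep an all-False mask row

Looking up apply_vocab_mask on type(grammar) rather than on the instance (the "force to use static method" comment in the diff) stores a plain two-argument function, so the batch holds one callable instead of a method bound to a particular grammar object.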