Fix the overlap for xgrammar (#2377)

Lianmin Zheng
2024-12-06 05:49:29 -08:00
committed by GitHub
parent 3cde5eb629
commit 0e7409adb6
7 changed files with 145 additions and 133 deletions


@@ -158,22 +158,23 @@ class SamplingBatchInfo:
         return
 
         # find a grammar from the list
-        grammar = next(grammar for grammar in self.grammars if grammar)
+        first_grammar = next(grammar for grammar in self.grammars if grammar)
 
         # maybe we can reuse the existing mask?
-        self.vocab_mask = grammar.allocate_vocab_mask(
+        self.vocab_mask = first_grammar.allocate_vocab_mask(
             vocab_size=self.vocab_size,
             batch_size=len(self.temperatures),
             device=self.device,
         )
-        self.apply_mask = type(grammar).apply_vocab_mask  # force to use static method
+        self.apply_mask = first_grammar.apply_vocab_mask  # force to use static method
 
         # Apply the mask
         for i, grammar in enumerate(self.grammars):
-            if grammar is not None:
-                try:
-                    grammar.fill_vocab_mask(self.vocab_mask, i)
-                except RuntimeError:
-                    continue
+            if grammar and not grammar.finished:
+                grammar.fill_vocab_mask(self.vocab_mask, i)
+
+        # Move the mask to the device if needed
+        self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
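
For readers outside the sglang codebase: the hunk touches every stage of the vocab-mask lifecycle (allocate, fill, move to device, apply). Below is a minimal self-contained sketch of that contract. Only the four method names (allocate_vocab_mask, fill_vocab_mask, move_vocab_mask, apply_vocab_mask) and the `finished` flag come from the diff; ToyGrammar, its allowed-token-list representation, and the CPU-side fill are illustrative assumptions, not sglang's actual xgrammar wrapper.

import torch


class ToyGrammar:
    """Illustrative stand-in for a grammar backend object (an assumption)."""

    def __init__(self, allowed_token_ids, finished=False):
        self.allowed_token_ids = allowed_token_ids
        self.finished = finished  # finished grammars are skipped during filling

    @staticmethod
    def allocate_vocab_mask(vocab_size, batch_size, device):
        # Allocate on CPU regardless of `device`; move_vocab_mask transfers it
        # later. This mimics a backend that fills the bitmask host-side.
        return torch.zeros(batch_size, vocab_size, dtype=torch.bool)

    def fill_vocab_mask(self, vocab_mask, idx):
        # Mark every token in row `idx` as forbidden, then clear the allowed
        # ones. True means "mask this token out".
        vocab_mask[idx, :] = True
        vocab_mask[idx, self.allowed_token_ids] = False

    @staticmethod
    def move_vocab_mask(vocab_mask, device):
        # One bulk host-to-device copy after all rows are filled.
        return vocab_mask.to(device, non_blocking=True)

    @staticmethod
    def apply_vocab_mask(logits, vocab_mask):
        logits.masked_fill_(vocab_mask, float("-inf"))


# The same call sequence as the patched update_regex_vocab_mask:
grammars = [ToyGrammar([1, 2]), None, ToyGrammar([0], finished=True)]
first_grammar = next(g for g in grammars if g)

vocab_mask = first_grammar.allocate_vocab_mask(vocab_size=4, batch_size=3, device="cpu")
for i, g in enumerate(grammars):
    if g and not g.finished:  # None rows and finished grammars stay unmasked
        g.fill_vocab_mask(vocab_mask, i)
vocab_mask = first_grammar.move_vocab_mask(vocab_mask, "cpu")

logits = torch.zeros(3, 4)
ToyGrammar.apply_vocab_mask(logits, vocab_mask)
print(logits)  # row 0: -inf everywhere except tokens 1 and 2; rows 1 and 2: all zeros

Note how the explicit `if grammar and not grammar.finished` guard replaces the old catch-all `try/except RuntimeError`: a mask row for a completed request is simply left all-False (nothing masked) instead of relying on the backend to raise when filling a finished grammar.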