Fix the overlap for xgrammar (#2377)

This commit is contained in:
Lianmin Zheng
2024-12-06 05:49:29 -08:00
committed by GitHub
parent 3cde5eb629
commit 0e7409adb6
7 changed files with 145 additions and 133 deletions

View File

@@ -45,6 +45,7 @@ class XGrammarGrammar(BaseGrammarObject):
self.matcher = matcher
self.vocab_size = vocab_size
self.ctx = ctx
self.finished = False
def accept_token(self, token: int) -> None:
    """Advance the grammar matcher by one generated token.

    NOTE(review): rejection is signalled via ``assert`` — under
    ``python -O`` this check is stripped, so an unaccepted token would
    pass silently.
    """
    # matcher.accept_token returns a truthy value when the token is
    # permitted by the grammar at the current position; presumably a
    # False return indicates an internal error — TODO confirm with the
    # xgrammar matcher API.
    assert self.matcher.accept_token(token)
@@ -85,12 +86,11 @@ class XGrammarGrammar(BaseGrammarObject):
self.matcher.fill_next_token_bitmask(vocab_mask, idx)
@staticmethod
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
    """Pre-fix version of the mask application (removed side of this diff).

    NOTE(review): the diff rendering shows only the device-check lines
    here; the trailing call that actually applies the bitmask is a shared
    context line displayed with the replacement version below — this span
    is not the complete original method body.
    """
    # Device mismatch was handled inline here; the commit splits this
    # out into a separate `move_vocab_mask` step so the transfer can
    # overlap with other work.
    if vocab_mask.device.type != logits.device.type:
        # vocab_mask must then be on the same device as logits
        # when applying the token bitmask, so we check and move if needed
        vocab_mask = vocab_mask.to(logits.device)
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
    """Return *vocab_mask* transferred to *device*.

    The transfer is issued with ``non_blocking=True`` so a host-to-GPU
    copy can overlap with other work. If the mask already resides on
    *device*, ``Tensor.to`` returns the original tensor unchanged.
    """
    moved = vocab_mask.to(device, non_blocking=True)
    return moved
@staticmethod
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
    """Apply the grammar token bitmask to *logits* in place.

    Delegates to xgrammar's ``apply_token_bitmask_inplace`` kernel.
    NOTE(review): presumably *vocab_mask* must already be on the same
    device as *logits* — the device transfer now happens in the separate
    ``move_vocab_mask`` step; confirm at call sites.
    """
    apply_token_bitmask_inplace(logits, vocab_mask)
def copy(self):