Fix the overlap for xgrammar (#2377)
This commit is contained in:
@@ -42,6 +42,7 @@ class OutlinesGrammar(BaseGrammarObject):
|
||||
self.guide = guide
|
||||
self.jump_forward_map = jump_forward_map
|
||||
self.state = 0
|
||||
self.finished = False
|
||||
|
||||
def accept_token(self, token: int):
    """Advance the Outlines guide FSM by one accepted token.

    The guide maps (current state, token) -> next state; the result is
    stored back on ``self.state``.
    """
    next_state = self.guide.get_next_state(self.state, token)
    self.state = next_state
|
||||
@@ -84,6 +85,10 @@ class OutlinesGrammar(BaseGrammarObject):
|
||||
) -> torch.Tensor:
|
||||
return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
|
||||
|
||||
@staticmethod
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
    """Return the mask unchanged: the Outlines backend applies its vocab
    mask on the device it was allocated on, so no transfer is needed here.

    The ``device`` argument is accepted only to match the interface of the
    other grammar backends (the xgrammar variant does move the tensor).
    """
    return vocab_mask
|
||||
|
||||
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
|
||||
tokens = torch.tensor(
|
||||
self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
|
||||
|
||||
@@ -45,6 +45,7 @@ class XGrammarGrammar(BaseGrammarObject):
|
||||
self.matcher = matcher
|
||||
self.vocab_size = vocab_size
|
||||
self.ctx = ctx
|
||||
self.finished = False
|
||||
|
||||
def accept_token(self, token: int):
    """Feed one generated token to the xgrammar matcher.

    The matcher must accept the token; a rejection indicates the sampled
    token violated the grammar, which is treated as an internal error.
    NOTE(review): this is an ``assert`` and is stripped under ``python -O``.
    """
    accepted = self.matcher.accept_token(token)
    assert accepted
|
||||
@@ -85,12 +86,11 @@ class XGrammarGrammar(BaseGrammarObject):
|
||||
self.matcher.fill_next_token_bitmask(vocab_mask, idx)
|
||||
|
||||
@staticmethod
|
||||
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
|
||||
if vocab_mask.device.type != logits.device.type:
|
||||
# vocab_mask must then be on the same device as logits
|
||||
# when applying the token bitmask, so we check and move if needed
|
||||
vocab_mask = vocab_mask.to(logits.device)
|
||||
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
    """Transfer the token bitmask to *device* for mask application.

    The copy is issued with ``non_blocking=True`` so it can overlap with
    other work (an actual async copy additionally requires pinned source
    memory when moving host -> GPU).
    """
    moved = vocab_mask.to(device, non_blocking=True)
    return moved
|
||||
|
||||
@staticmethod
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
    """Apply the xgrammar token bitmask to ``logits`` in place.

    Delegates to xgrammar's ``apply_token_bitmask_inplace``; tokens not
    allowed by the bitmask have their logits masked out. The mask is
    expected to already be on the same device as ``logits`` (see
    ``move_vocab_mask``).
    """
    apply_token_bitmask_inplace(logits, vocab_mask)
|
||||
|
||||
def copy(self):
|
||||
|
||||
Reference in New Issue
Block a user