Fix regex mask (#1296)

This commit is contained in:
Liangsheng Yin
2024-09-01 21:50:58 -07:00
committed by GitHub
parent 4a9f8ea43b
commit 47f20da223
2 changed files with 6 additions and 6 deletions

View File

@@ -154,15 +154,15 @@ class SamplingBatchInfo:
self.vocab_mask = None
if has_regex:
self.vocab_mask = torch.zeros(
bs, self.vocab_size, dtype=torch.bool, device=device
)
for i, req in enumerate(reqs):
if req.regex_fsm is not None:
if self.vocab_mask is None:
self.vocab_mask = torch.zeros(
bs, self.vocab_size, dtype=torch.bool, device=device
)
self.vocab_mask[i].fill_(1)
self.vocab_mask[i][
req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
] = 1
] = 0
def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)