Fix regex mask (#1296)
This commit is contained in:
@@ -63,7 +63,7 @@ class Sampler(CustomOp):
|
|||||||
logits.add_(sampling_info.logit_bias)
|
logits.add_(sampling_info.logit_bias)
|
||||||
|
|
||||||
if sampling_info.vocab_mask is not None:
|
if sampling_info.vocab_mask is not None:
|
||||||
logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))
|
logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
|
||||||
|
|
||||||
logits = self._apply_penalties(logits, sampling_info)
|
logits = self._apply_penalties(logits, sampling_info)
|
||||||
|
|
||||||
|
|||||||
@@ -154,15 +154,15 @@ class SamplingBatchInfo:
|
|||||||
self.vocab_mask = None
|
self.vocab_mask = None
|
||||||
|
|
||||||
if has_regex:
|
if has_regex:
|
||||||
|
self.vocab_mask = torch.zeros(
|
||||||
|
bs, self.vocab_size, dtype=torch.bool, device=device
|
||||||
|
)
|
||||||
for i, req in enumerate(reqs):
|
for i, req in enumerate(reqs):
|
||||||
if req.regex_fsm is not None:
|
if req.regex_fsm is not None:
|
||||||
if self.vocab_mask is None:
|
self.vocab_mask[i].fill_(1)
|
||||||
self.vocab_mask = torch.zeros(
|
|
||||||
bs, self.vocab_size, dtype=torch.bool, device=device
|
|
||||||
)
|
|
||||||
self.vocab_mask[i][
|
self.vocab_mask[i][
|
||||||
req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
|
req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
|
||||||
] = 1
|
] = 0
|
||||||
|
|
||||||
def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
||||||
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
|
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
|
||||||
|
|||||||
Reference in New Issue
Block a user