Fix regex mask (#1296)
This commit is contained in:
@@ -63,7 +63,7 @@ class Sampler(CustomOp):
|
||||
logits.add_(sampling_info.logit_bias)
|
||||
|
||||
if sampling_info.vocab_mask is not None:
|
||||
logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))
|
||||
logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
|
||||
|
||||
logits = self._apply_penalties(logits, sampling_info)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user