Fix regex mask (#1296)

This commit is contained in:
Liangsheng Yin
2024-09-01 21:50:58 -07:00
committed by GitHub
parent 4a9f8ea43b
commit 47f20da223
2 changed files with 6 additions and 6 deletions

View File

@@ -63,7 +63,7 @@ class Sampler(CustomOp):
logits.add_(sampling_info.logit_bias)
if sampling_info.vocab_mask is not None:
logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))
logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
logits = self._apply_penalties(logits, sampling_info)