Minor follow-up fixes for the logprob refactor (#2670)

This commit is contained in:
Lianmin Zheng
2024-12-30 05:42:08 -08:00
committed by GitHub
parent c5210dfa38
commit 21ec66e59e
5 changed files with 11 additions and 12 deletions

View File

@@ -244,7 +244,7 @@ class SamplingBatchInfo:
# repetition
if self.scaling_penalties is not None:
logits = torch.where(
logits[:] = torch.where(
logits > 0,
logits / self.scaling_penalties,
logits * self.scaling_penalties,
@@ -253,5 +253,3 @@ class SamplingBatchInfo:
# Apply regex vocab_mask
if self.vocab_mask is not None:
self.apply_mask(logits=logits, vocab_mask=self.vocab_mask)
return logits