Minor follow-up fixes for the logprob refactor (#2670)
This commit is contained in:
@@ -244,7 +244,7 @@ class SamplingBatchInfo:
|
||||
|
||||
# repetition
|
||||
if self.scaling_penalties is not None:
|
||||
logits = torch.where(
|
||||
logits[:] = torch.where(
|
||||
logits > 0,
|
||||
logits / self.scaling_penalties,
|
||||
logits * self.scaling_penalties,
|
||||
@@ -253,5 +253,3 @@ class SamplingBatchInfo:
|
||||
# Apply regex vocab_mask
|
||||
if self.vocab_mask is not None:
|
||||
self.apply_mask(logits=logits, vocab_mask=self.vocab_mask)
|
||||
|
||||
return logits
|
||||
|
||||
Reference in New Issue
Block a user