Fix repetition penalty (#3139)
This commit is contained in:
@@ -67,6 +67,7 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
||||
|
||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
|
||||
return logits
|
||||
|
||||
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
|
||||
|
||||
Reference in New Issue
Block a user