Fix repetition penalty (#3139)

This commit is contained in:
Lianmin Zheng
2025-01-25 21:48:58 -08:00
committed by GitHub
parent 66283dbc0c
commit 4f118a39d7
2 changed files with 9 additions and 8 deletions

View File

@@ -67,6 +67,7 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
return logits
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]