Simplify logits penalizer (#2086)

This commit is contained in:
Lianmin Zheng
2024-11-18 17:48:28 -08:00
committed by GitHub
parent 3b44bbeecf
commit b110453802
18 changed files with 125 additions and 190 deletions

View File

@@ -1,5 +1,5 @@
import typing
import unittest
from typing import List
import torch
@@ -78,7 +78,7 @@ class TestBatchedRepetitionPenalizer(BaseBatchedPenalizerTest):
],
)
def create_test_subjects(self) -> typing.List[Subject]:
def create_test_subjects(self) -> List[Subject]:
self.enabled = self._create_subject(repetition_penalty=REPETITION_PENALTY)
self.disabled = self._create_subject(repetition_penalty=1.0)