Simplify logits penalizer (#2086)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import typing
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
@@ -78,7 +78,7 @@ class TestBatchedRepetitionPenalizer(BaseBatchedPenalizerTest):
|
||||
],
|
||||
)
|
||||
|
||||
def create_test_subjects(self) -> typing.List[Subject]:
|
||||
def create_test_subjects(self) -> List[Subject]:
|
||||
self.enabled = self._create_subject(repetition_penalty=REPETITION_PENALTY)
|
||||
self.disabled = self._create_subject(repetition_penalty=1.0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user