Simplify logits penalizer (#2086)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import typing
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
@@ -48,7 +48,11 @@ class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest):
|
||||
),
|
||||
Step(
|
||||
type=StepType.OUTPUT,
|
||||
token_ids=[1, 2, 2],
|
||||
token_ids=[
|
||||
1,
|
||||
2,
|
||||
2,
|
||||
], # This is the output ids of one request in three steps.
|
||||
expected_tensors={
|
||||
"frequency_penalties": self.tensor(
|
||||
[[frequency_penalty] * self.vocab_size], dtype=torch.float32
|
||||
@@ -76,7 +80,7 @@ class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest):
|
||||
],
|
||||
)
|
||||
|
||||
def create_test_subjects(self) -> typing.List[Subject]:
|
||||
def create_test_subjects(self) -> List[Subject]:
|
||||
self.enabled = self._create_subject(frequency_penalty=self.frequency_penalty)
|
||||
self.disabled = self._create_subject(frequency_penalty=0.0)
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import typing
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
@@ -143,7 +143,7 @@ class TestBatchedMinNewTokensPenalizer(BaseBatchedPenalizerTest):
|
||||
],
|
||||
)
|
||||
|
||||
def create_test_subjects(self) -> typing.List[Subject]:
|
||||
def create_test_subjects(self) -> List[Subject]:
|
||||
self.enabled = self._create_subject(min_new_tokens=MIN_NEW_TOKENS)
|
||||
self.disabled = self._create_subject(min_new_tokens=0.0)
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import typing
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
@@ -76,7 +76,7 @@ class BaseBatchedPresencePenalizerTest(BaseBatchedPenalizerTest):
|
||||
],
|
||||
)
|
||||
|
||||
def create_test_subjects(self) -> typing.List[Subject]:
|
||||
def create_test_subjects(self) -> List[Subject]:
|
||||
self.enabled = self._create_subject(presence_penalty=self.presence_penalty)
|
||||
self.disabled = self._create_subject(presence_penalty=0.0)
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import typing
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
@@ -78,7 +78,7 @@ class TestBatchedRepetitionPenalizer(BaseBatchedPenalizerTest):
|
||||
],
|
||||
)
|
||||
|
||||
def create_test_subjects(self) -> typing.List[Subject]:
|
||||
def create_test_subjects(self) -> List[Subject]:
|
||||
self.enabled = self._create_subject(repetition_penalty=REPETITION_PENALTY)
|
||||
self.disabled = self._create_subject(repetition_penalty=1.0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user