Simplify logits penalizer (#2086)

This commit is contained in:
Lianmin Zheng
2024-11-18 17:48:28 -08:00
committed by GitHub
parent 3b44bbeecf
commit b110453802
18 changed files with 125 additions and 190 deletions

View File

@@ -1,5 +1,5 @@
import typing
import unittest
from typing import List
import torch
@@ -48,7 +48,11 @@ class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest):
),
Step(
type=StepType.OUTPUT,
token_ids=[1, 2, 2],
token_ids=[
1,
2,
2,
], # This is the output ids of one request in three steps.
expected_tensors={
"frequency_penalties": self.tensor(
[[frequency_penalty] * self.vocab_size], dtype=torch.float32
@@ -76,7 +80,7 @@ class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest):
],
)
def create_test_subjects(self) -> typing.List[Subject]:
def create_test_subjects(self) -> List[Subject]:
self.enabled = self._create_subject(frequency_penalty=self.frequency_penalty)
self.disabled = self._create_subject(frequency_penalty=0.0)

View File

@@ -1,5 +1,5 @@
import typing
import unittest
from typing import List
import torch
@@ -143,7 +143,7 @@ class TestBatchedMinNewTokensPenalizer(BaseBatchedPenalizerTest):
],
)
def create_test_subjects(self) -> typing.List[Subject]:
def create_test_subjects(self) -> List[Subject]:
self.enabled = self._create_subject(min_new_tokens=MIN_NEW_TOKENS)
self.disabled = self._create_subject(min_new_tokens=0.0)

View File

@@ -1,5 +1,5 @@
import typing
import unittest
from typing import List
import torch
@@ -76,7 +76,7 @@ class BaseBatchedPresencePenalizerTest(BaseBatchedPenalizerTest):
],
)
def create_test_subjects(self) -> typing.List[Subject]:
def create_test_subjects(self) -> List[Subject]:
self.enabled = self._create_subject(presence_penalty=self.presence_penalty)
self.disabled = self._create_subject(presence_penalty=0.0)

View File

@@ -1,5 +1,5 @@
import typing
import unittest
from typing import List
import torch
@@ -78,7 +78,7 @@ class TestBatchedRepetitionPenalizer(BaseBatchedPenalizerTest):
],
)
def create_test_subjects(self) -> typing.List[Subject]:
def create_test_subjects(self) -> List[Subject]:
self.enabled = self._create_subject(repetition_penalty=REPETITION_PENALTY)
self.disabled = self._create_subject(repetition_penalty=1.0)