Simplify logits penalizer (#2086)

This commit is contained in:
Lianmin Zheng
2024-11-18 17:48:28 -08:00
committed by GitHub
parent 3b44bbeecf
commit b110453802
18 changed files with 125 additions and 190 deletions

View File

@@ -1,5 +1,5 @@
import typing
import unittest
from typing import List
import torch
@@ -48,7 +48,11 @@ class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest):
),
Step(
type=StepType.OUTPUT,
token_ids=[1, 2, 2],
token_ids=[
1,
2,
2,
], # This is the output ids of one request in three steps.
expected_tensors={
"frequency_penalties": self.tensor(
[[frequency_penalty] * self.vocab_size], dtype=torch.float32
@@ -76,7 +80,7 @@ class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest):
],
)
def create_test_subjects(self) -> typing.List[Subject]:
def create_test_subjects(self) -> List[Subject]:
self.enabled = self._create_subject(frequency_penalty=self.frequency_penalty)
self.disabled = self._create_subject(frequency_penalty=0.0)