Simplify logits penalizer (#2086)
This commit is contained in:
@@ -1,3 +1,8 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m unittest test_srt_backend.TestSRTBackend.test_gen_min_new_tokens
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import sglang as sgl
|
||||
@@ -68,7 +73,7 @@ class TestSRTBackend(unittest.TestCase):
|
||||
# Run twice to capture more bugs
|
||||
for _ in range(2):
|
||||
accuracy, latency = test_hellaswag_select()
|
||||
assert accuracy > 0.71, f"{accuracy=}"
|
||||
self.assertGreater(accuracy, 0.71)
|
||||
|
||||
def test_gen_min_new_tokens(self):
|
||||
test_gen_min_new_tokens()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import typing
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
@@ -48,7 +48,11 @@ class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest):
|
||||
),
|
||||
Step(
|
||||
type=StepType.OUTPUT,
|
||||
token_ids=[1, 2, 2],
|
||||
token_ids=[
|
||||
1,
|
||||
2,
|
||||
2,
|
||||
], # This is the output ids of one request in three steps.
|
||||
expected_tensors={
|
||||
"frequency_penalties": self.tensor(
|
||||
[[frequency_penalty] * self.vocab_size], dtype=torch.float32
|
||||
@@ -76,7 +80,7 @@ class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest):
|
||||
],
|
||||
)
|
||||
|
||||
def create_test_subjects(self) -> typing.List[Subject]:
|
||||
def create_test_subjects(self) -> List[Subject]:
|
||||
self.enabled = self._create_subject(frequency_penalty=self.frequency_penalty)
|
||||
self.disabled = self._create_subject(frequency_penalty=0.0)
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import typing
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
@@ -143,7 +143,7 @@ class TestBatchedMinNewTokensPenalizer(BaseBatchedPenalizerTest):
|
||||
],
|
||||
)
|
||||
|
||||
def create_test_subjects(self) -> typing.List[Subject]:
|
||||
def create_test_subjects(self) -> List[Subject]:
|
||||
self.enabled = self._create_subject(min_new_tokens=MIN_NEW_TOKENS)
|
||||
self.disabled = self._create_subject(min_new_tokens=0.0)
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import typing
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
@@ -76,7 +76,7 @@ class BaseBatchedPresencePenalizerTest(BaseBatchedPenalizerTest):
|
||||
],
|
||||
)
|
||||
|
||||
def create_test_subjects(self) -> typing.List[Subject]:
|
||||
def create_test_subjects(self) -> List[Subject]:
|
||||
self.enabled = self._create_subject(presence_penalty=self.presence_penalty)
|
||||
self.disabled = self._create_subject(presence_penalty=0.0)
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import typing
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
@@ -78,7 +78,7 @@ class TestBatchedRepetitionPenalizer(BaseBatchedPenalizerTest):
|
||||
],
|
||||
)
|
||||
|
||||
def create_test_subjects(self) -> typing.List[Subject]:
|
||||
def create_test_subjects(self) -> List[Subject]:
|
||||
self.enabled = self._create_subject(repetition_penalty=REPETITION_PENALTY)
|
||||
self.disabled = self._create_subject(repetition_penalty=1.0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user