test: negative value testing for frequency, presence penalizers (#995)

This commit is contained in:
Juwan Yoo
2024-08-08 23:30:50 -07:00
committed by GitHub
parent e040a2450b
commit 95a28019ba
2 changed files with 34 additions and 8 deletions

View File

@@ -14,11 +14,16 @@ from sglang.test.srt.sampling.penaltylib.utils import (
Subject,
)
PRESENCE_PENALTY = 0.12
class TestBatchedPresencePenalizer(BaseBatchedPenalizerTest):
class BaseBatchedPresencePenalizerTest(BaseBatchedPenalizerTest):
Penalizer = BatchedPresencePenalizer
presence_penalty: float
def setUp(self):
if self.__class__ == BaseBatchedPresencePenalizerTest:
self.skipTest("Base class for presence_penalty tests")
super().setUp()
def _create_subject(self, presence_penalty: float) -> Subject:
return Subject(
@@ -72,9 +77,17 @@ class TestBatchedPresencePenalizer(BaseBatchedPenalizerTest):
)
def create_test_subjects(self) -> typing.List[Subject]:
self.enabled = self._create_subject(presence_penalty=PRESENCE_PENALTY)
self.enabled = self._create_subject(presence_penalty=self.presence_penalty)
self.disabled = self._create_subject(presence_penalty=0.0)
class TestBatchedPresencePenalizerPositiveValue(BaseBatchedPresencePenalizerTest):
presence_penalty = 0.12
class TestBatchedPresencePenalizerNegativeValue(BaseBatchedPresencePenalizerTest):
presence_penalty = -0.12
if __name__ == "__main__":
unittest.main()