From 95a28019ba6c7288c1d2e747665d6a9dd005fdc2 Mon Sep 17 00:00:00 2001 From: Juwan Yoo Date: Thu, 8 Aug 2024 23:30:50 -0700 Subject: [PATCH] test: negative value testing for frequency, presence penalizers (#995) --- .../penalizers/test_frequency_penalty.py | 21 +++++++++++++++---- .../penalizers/test_presence_penalty.py | 21 +++++++++++++++---- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/test/srt/sampling/penaltylib/penalizers/test_frequency_penalty.py b/test/srt/sampling/penaltylib/penalizers/test_frequency_penalty.py index b659a04fc..59db353ab 100644 --- a/test/srt/sampling/penaltylib/penalizers/test_frequency_penalty.py +++ b/test/srt/sampling/penaltylib/penalizers/test_frequency_penalty.py @@ -14,11 +14,16 @@ from sglang.test.srt.sampling.penaltylib.utils import ( Subject, ) -FREQUENCY_PENALTY = 0.12 - -class TestBatchedFrequencyPenalizer(BaseBatchedPenalizerTest): +class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest): Penalizer = BatchedFrequencyPenalizer + frequency_penalty: float + + def setUp(self): + if self.__class__ == BaseBatchedFrequencyPenalizerTest: + self.skipTest("Base class for frequency_penalty tests") + + super().setUp() def _create_subject(self, frequency_penalty: float) -> Subject: return Subject( @@ -72,9 +77,17 @@ class TestBatchedFrequencyPenalizer(BaseBatchedPenalizerTest): ) def create_test_subjects(self) -> typing.List[Subject]: - self.enabled = self._create_subject(frequency_penalty=FREQUENCY_PENALTY) + self.enabled = self._create_subject(frequency_penalty=self.frequency_penalty) self.disabled = self._create_subject(frequency_penalty=0.0) +class TestBatchedFrequencyPenalizerPositiveValue(BaseBatchedFrequencyPenalizerTest): + frequency_penalty = 0.12 + + +class TestBatchedFrequencyPenalizerNegativeValue(BaseBatchedFrequencyPenalizerTest): + frequency_penalty = -0.12 + + if __name__ == "__main__": unittest.main() diff --git a/test/srt/sampling/penaltylib/penalizers/test_presence_penalty.py b/test/srt/sampling/penaltylib/penalizers/test_presence_penalty.py index 30cb2b9a0..96cbf1082 100644 --- a/test/srt/sampling/penaltylib/penalizers/test_presence_penalty.py +++ b/test/srt/sampling/penaltylib/penalizers/test_presence_penalty.py @@ -14,11 +14,16 @@ from sglang.test.srt.sampling.penaltylib.utils import ( Subject, ) -PRESENCE_PENALTY = 0.12 - -class TestBatchedPresencePenalizer(BaseBatchedPenalizerTest): +class BaseBatchedPresencePenalizerTest(BaseBatchedPenalizerTest): Penalizer = BatchedPresencePenalizer + presence_penalty: float + + def setUp(self): + if self.__class__ == BaseBatchedPresencePenalizerTest: + self.skipTest("Base class for presence_penalty tests") + + super().setUp() def _create_subject(self, presence_penalty: float) -> Subject: return Subject( @@ -72,9 +77,17 @@ class TestBatchedPresencePenalizer(BaseBatchedPenalizerTest): ) def create_test_subjects(self) -> typing.List[Subject]: - self.enabled = self._create_subject(presence_penalty=PRESENCE_PENALTY) + self.enabled = self._create_subject(presence_penalty=self.presence_penalty) self.disabled = self._create_subject(presence_penalty=0.0) +class TestBatchedPresencePenalizerPositiveValue(BaseBatchedPresencePenalizerTest): + presence_penalty = 0.12 + + +class TestBatchedPresencePenalizerNegativeValue(BaseBatchedPresencePenalizerTest): + presence_penalty = -0.12 + + if __name__ == "__main__": unittest.main()