From 57a404fd55a94116d3b7ff935c21cd2065cd9917 Mon Sep 17 00:00:00 2001 From: Qiaolin Yu Date: Mon, 3 Mar 2025 12:41:38 -0500 Subject: [PATCH] Remove outdated test utils and fix links for the doc of sampling params (#3999) --- python/sglang/api.py | 2 +- python/sglang/lang/ir.py | 2 +- python/sglang/srt/sampling/sampling_params.py | 2 +- .../test/srt/sampling/penaltylib/utils.py | 344 ------------------ 4 files changed, 3 insertions(+), 347 deletions(-) delete mode 100644 python/sglang/test/srt/sampling/penaltylib/utils.py diff --git a/python/sglang/api.py b/python/sglang/api.py index 7ef306380..2bd39d5ee 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -94,7 +94,7 @@ def gen( regex: Optional[str] = None, json_schema: Optional[str] = None, ): - """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md""" + """Call the model to generate. See the meaning of the arguments in docs/backend/sampling_params.md""" if choices: return SglSelect( diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index d3a7430a8..0431d2c6b 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -457,7 +457,7 @@ class SglGen(SglExpr): regex: Optional[str] = None, json_schema: Optional[str] = None, ): - """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md""" + """Call the model to generate. See the meaning of the arguments in docs/backend/sampling_params.md""" super().__init__() self.name = name self.sampling_params = SglSamplingParams( diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py index 6a3c385a3..ffa2875e9 100644 --- a/python/sglang/srt/sampling/sampling_params.py +++ b/python/sglang/srt/sampling/sampling_params.py @@ -22,7 +22,7 @@ class SamplingParams: """ The sampling parameters. - See docs/references/sampling_params.md or + See docs/backend/sampling_params.md or https://docs.sglang.ai/backend/sampling_params.html for the documentation. """ diff --git a/python/sglang/test/srt/sampling/penaltylib/utils.py b/python/sglang/test/srt/sampling/penaltylib/utils.py deleted file mode 100644 index 431efa9b3..000000000 --- a/python/sglang/test/srt/sampling/penaltylib/utils.py +++ /dev/null @@ -1,344 +0,0 @@ -import dataclasses -import enum -import unittest -from typing import Dict, List, Optional, Set, Tuple, Type - -import torch - -from sglang.srt.sampling.penaltylib.orchestrator import ( - BatchedPenalizerOrchestrator, - _BatchedPenalizer, - _BatchLike, -) - - -@dataclasses.dataclass -class MockSamplingParams: - frequency_penalty: float = 0.0 - min_new_tokens: int = 0 - stop_token_ids: List[int] = None - presence_penalty: float = 0.0 - repetition_penalty: float = 1.0 - - -@dataclasses.dataclass -class MockTokenizer: - eos_token_id: int - additional_stop_token_ids: Optional[List[int]] = None - - -@dataclasses.dataclass -class MockReq: - origin_input_ids: List[int] - sampling_params: MockSamplingParams - tokenizer: MockTokenizer - - -class StepType(enum.Enum): - INPUT = "input" - OUTPUT = "output" - - -@dataclasses.dataclass -class Step: - type: StepType - token_ids: List[int] - expected_tensors: Dict[str, torch.Tensor] - # assume initial logits are all 1 - expected_logits: torch.Tensor - - -@dataclasses.dataclass -class Subject: - sampling_params: MockSamplingParams - # first step must be input, which will be converted to Req - steps: List[Step] - eos_token_id: int = -1 - - def __post_init__(self): - if self.steps[0].type != StepType.INPUT: - raise ValueError("First step must be input") - - # each steps should have the same expected_tensors.keys() - for i in range(1, len(self.steps)): - if self.tensor_keys(i) != self.tensor_keys(): - raise ValueError( - f"Expected tensors keys must be the same for all steps. Got {self.steps[i].expected_tensors.keys()} for key={i} and {self.steps[0].expected_tensors.keys()}" - ) - - def tensor_keys(self, i: int = 0) -> Set[str]: - return set(self.steps[i].expected_tensors.keys()) - - def to_req(self) -> MockReq: - return MockReq( - origin_input_ids=self.steps[0].token_ids, - sampling_params=self.sampling_params, - tokenizer=MockTokenizer(eos_token_id=self.eos_token_id), - ) - - -@dataclasses.dataclass -class Case: - enabled: bool - test_subjects: List[Subject] - - def __post_init__(self): - # each test_subjects.steps should have the same expected_tensors.keys() - for i in range(1, len(self.test_subjects)): - if self.tensor_keys(i) != self.tensor_keys(): - raise ValueError( - f"Expected tensors keys must be the same for all test_subjects. Got {self.test_subjects[i].tensor_keys()} for key={i} and {self.test_subjects[0].tensor_keys()}" - ) - - def tensor_keys(self, i: int = 0) -> List[str]: - return set(self.test_subjects[i].tensor_keys()) - - -class BaseBatchedPenalizerTest(unittest.TestCase): - Penalizer: Type[_BatchedPenalizer] - device = "cuda" - vocab_size = 5 - - enabled: Subject = None - disabled: Subject = None - - def setUp(self): - if self.__class__ == BaseBatchedPenalizerTest: - self.skipTest("Base class for penalizer tests") - - self.create_test_subjects() - self.create_test_cases() - - def tensor(self, data, **kwargs) -> torch.Tensor: - """ - Shortcut to create a tensor with device=self.device. - """ - return torch.tensor(data, **kwargs, device=self.device) - - def create_test_subjects(self) -> List[Subject]: - raise NotImplementedError() - - def create_test_cases(self): - self.test_cases = [ - Case(enabled=True, test_subjects=[self.enabled]), - Case(enabled=False, test_subjects=[self.disabled]), - Case(enabled=True, test_subjects=[self.enabled, self.disabled]), - ] - - def _create_penalizer( - self, case: Case - ) -> Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]: - orchestrator = BatchedPenalizerOrchestrator( - vocab_size=self.vocab_size, - batch=_BatchLike(reqs=[subject.to_req() for subject in case.test_subjects]), - device=self.device, - Penalizers={self.Penalizer}, - ) - - return orchestrator, orchestrator.penalizers[self.Penalizer] - - def test_is_required(self): - for case in self.test_cases: - with self.subTest(case=case): - _, penalizer = self._create_penalizer(case) - self.assertEqual(case.enabled, penalizer.is_required()) - - def test_prepare(self): - for case in self.test_cases: - with self.subTest(case=case): - orchestrator, penalizer = self._create_penalizer(case) - self.assertEqual(case.enabled, penalizer.is_prepared()) - - if case.enabled: - for key, tensor in { - key: torch.cat( - tensors=[ - subject.steps[0].expected_tensors[key] - for subject in case.test_subjects - ], - ) - for key in case.tensor_keys() - }.items(): - torch.testing.assert_close( - actual=getattr(penalizer, key), - expected=tensor, - msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}", - ) - - original = torch.ones( - size=(len(case.test_subjects), self.vocab_size), - dtype=torch.float32, - device=self.device, - ) - actual = orchestrator.apply(original.clone()) - expected = torch.cat( - tensors=[ - subject.steps[0].expected_logits - for subject in case.test_subjects - ], - ) - if actual is None: - actual = original - torch.testing.assert_close( - actual=actual, - expected=expected, - msg=f"logits\nactual={actual}\nexpected={expected}", - ) - - def test_teardown(self): - for case in self.test_cases: - with self.subTest(case=case): - _, penalizer = self._create_penalizer(case) - penalizer.teardown() - - for key in case.test_subjects[0].steps[0].expected_tensors.keys(): - self.assertIsNone(getattr(penalizer, key, None)) - - def test_filter(self): - for case in self.test_cases: - with self.subTest(case=case): - orchestrator, penalizer = self._create_penalizer(case) - - indices_to_keep = [0] - orchestrator.filter(indices_to_keep=indices_to_keep) - - filtered_subjects = [case.test_subjects[i] for i in indices_to_keep] - - if penalizer.is_required(): - self.assertTrue(penalizer.is_prepared()) - for key, tensor in { - key: torch.cat( - tensors=[ - subject.steps[0].expected_tensors[key] - for subject in filtered_subjects - ], - ) - for key in case.tensor_keys() - }.items(): - torch.testing.assert_close( - actual=getattr(penalizer, key), - expected=tensor, - msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}", - ) - - actual_logits = orchestrator.apply( - torch.ones( - size=(len(filtered_subjects), self.vocab_size), - dtype=torch.float32, - device=self.device, - ) - ) - if actual_logits is None: - continue - filtered_expected_logits = torch.cat( - tensors=[ - subject.steps[0].expected_logits - for subject in filtered_subjects - ], - ) - torch.testing.assert_close( - actual=actual_logits, - expected=filtered_expected_logits, - msg=f"logits\nactual={actual_logits}\nexpected={filtered_expected_logits}", - ) - - def test_merge_enabled_with_disabled(self): - enabled_test_case = self.test_cases[0] - disabled_test_case = self.test_cases[1] - - orchestrator, penalizer = self._create_penalizer(enabled_test_case) - theirs, _ = self._create_penalizer(disabled_test_case) - - orchestrator.merge(theirs) - - for key, tensor in { - key: torch.cat( - tensors=[ - enabled_test_case.test_subjects[0].steps[0].expected_tensors[key], - disabled_test_case.test_subjects[0].steps[0].expected_tensors[key], - ], - ) - for key in enabled_test_case.tensor_keys() - }.items(): - torch.testing.assert_close( - actual=getattr(penalizer, key), - expected=tensor, - msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}", - ) - - def test_cumulate_apply_repeat(self): - for case in self.test_cases: - with self.subTest(case=case): - orchestrator, penalizer = self._create_penalizer(case) - - max_step = max(len(subject.steps) for subject in case.test_subjects) - for i in range(1, max_step): - orchestrator.filter( - indices_to_keep=[ - j - for j, subject in enumerate(case.test_subjects) - if i < len(subject.steps) - ] - ) - - filtered_subjects = [ - subject - for subject in case.test_subjects - if i < len(subject.steps) - ] - - inputs: List[List[int]] = [] - outputs: List[List[int]] = [] - for subject in filtered_subjects: - step = subject.steps[i] - if step.type == StepType.INPUT: - raise NotImplementedError() - else: - inputs.append([]) - outputs.append(step.token_ids) - - if any(outputs): - for j in range(max(len(x) for x in outputs)): - tmp_outputs = torch.tensor( - [x[j] for x in outputs], - dtype=torch.int32, - device=orchestrator.device, - ) - orchestrator.cumulate_output_tokens(tmp_outputs) - - if penalizer.is_required(): - self.assertTrue(penalizer.is_prepared()) - for key, tensor in { - key: torch.cat( - tensors=[ - subject.steps[i].expected_tensors[key] - for subject in filtered_subjects - ], - ) - for key in case.tensor_keys() - }.items(): - torch.testing.assert_close( - actual=getattr(penalizer, key), - expected=tensor, - msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}", - ) - - original = torch.ones( - size=(len(filtered_subjects), self.vocab_size), - dtype=torch.float32, - device=self.device, - ) - actual_logits = orchestrator.apply(original.clone()) - filtered_expected_logits = torch.cat( - tensors=[ - subject.steps[i].expected_logits - for subject in filtered_subjects - ], - ) - if actual_logits is None: - actual_logits = original - torch.testing.assert_close( - actual=actual_logits, - expected=filtered_expected_logits, - msg=f"logits\nactual={actual_logits}\nexpected={filtered_expected_logits}", - )