"""Shared test harness for batched penalizer implementations.

Defines mock request/sampling-parameter objects plus a BaseBatchedPenalizerTest
class that concrete penalizer tests subclass.
"""

import dataclasses
import enum
import typing
import unittest

import torch

from sglang.srt.sampling.penaltylib.orchestrator import (
    BatchedPenalizerOrchestrator,
    _BatchedPenalizer,
    _BatchLike,
)

@dataclasses.dataclass
class MockSamplingParams:
    frequency_penalty: float = 0.0
    min_new_tokens: int = 0
    stop_token_ids: typing.Optional[typing.List[int]] = None
    presence_penalty: float = 0.0
    repetition_penalty: float = 1.0

@dataclasses.dataclass
class MockTokenizer:
    eos_token_id: int

@dataclasses.dataclass
class MockReq:
    origin_input_ids: typing.List[int]
    sampling_params: MockSamplingParams
    tokenizer: MockTokenizer

class StepType(enum.Enum):
    INPUT = "input"
    OUTPUT = "output"

@dataclasses.dataclass
class Step:
    type: StepType
    token_ids: typing.List[int]
    expected_tensors: typing.Dict[str, torch.Tensor]
    # assume initial logits are all 1
    expected_logits: torch.Tensor

@dataclasses.dataclass
class Subject:
    sampling_params: MockSamplingParams
    # first step must be input, which will be converted to Req
    steps: typing.List[Step]
    eos_token_id: int = -1

    def __post_init__(self):
        if self.steps[0].type != StepType.INPUT:
            raise ValueError("First step must be input")

        # every step must expose the same expected_tensors keys
        for i in range(1, len(self.steps)):
            if self.tensor_keys(i) != self.tensor_keys():
                raise ValueError(
                    f"Expected tensor keys must be the same for all steps. "
                    f"Got {self.steps[i].expected_tensors.keys()} for step {i} "
                    f"and {self.steps[0].expected_tensors.keys()} for step 0"
                )

    def tensor_keys(self, i: int = 0) -> typing.Set[str]:
        return set(self.steps[i].expected_tensors.keys())

    def to_req(self) -> MockReq:
        return MockReq(
            origin_input_ids=self.steps[0].token_ids,
            sampling_params=self.sampling_params,
            tokenizer=MockTokenizer(eos_token_id=self.eos_token_id),
        )

@dataclasses.dataclass
class Case:
    enabled: bool
    test_subjects: typing.List[Subject]

    def __post_init__(self):
        # every test subject must expose the same expected_tensors keys
        for i in range(1, len(self.test_subjects)):
            if self.tensor_keys(i) != self.tensor_keys():
                raise ValueError(
                    f"Expected tensor keys must be the same for all test_subjects. "
                    f"Got {self.test_subjects[i].tensor_keys()} for subject {i} "
                    f"and {self.test_subjects[0].tensor_keys()} for subject 0"
                )

    def tensor_keys(self, i: int = 0) -> typing.Set[str]:
        return set(self.test_subjects[i].tensor_keys())

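# The following is an illustrative sketch, not part of the original harness: it
# shows how a Step/Subject pair might be described for a hypothetical penalizer
# whose state lives in a single tensor attribute. The attribute name
# "penalty_state", the token ids, and the expected values are placeholders, not
# the real keys or semantics of any sglang penalizer; real subjects must use the
# tensor attribute names of the penalizer under test and build tensors with
# self.tensor(...) so they land on the right device.
#
#     enabled_subject = Subject(
#         sampling_params=MockSamplingParams(presence_penalty=0.5),
#         steps=[
#             Step(
#                 type=StepType.INPUT,
#                 token_ids=[0, 1, 2],
#                 expected_tensors={"penalty_state": torch.tensor([[0.5] * 5])},
#                 # logits start at 1.0 everywhere, so apply() should return them unchanged
#                 expected_logits=torch.tensor([[1.0] * 5]),
#             ),
#             Step(
#                 type=StepType.OUTPUT,
#                 token_ids=[3],
#                 expected_tensors={"penalty_state": torch.tensor([[0.5] * 5])},
#                 # token 3 was sampled, so its logit is expected to be penalized
#                 expected_logits=torch.tensor([[1.0, 1.0, 1.0, 0.5, 1.0]]),
#             ),
#         ],
#     )
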
class BaseBatchedPenalizerTest(unittest.TestCase):
    """Abstract harness: subclasses set `Penalizer`, build `enabled` / `disabled`
    subjects in create_test_subjects(), and inherit every test_* method below."""

    Penalizer: typing.Type[_BatchedPenalizer]
    device = "cuda"
    vocab_size = 5

    enabled: typing.Optional[Subject] = None
    disabled: typing.Optional[Subject] = None

    def setUp(self):
        if self.__class__ == BaseBatchedPenalizerTest:
            self.skipTest("Base class for penalizer tests")

        self.create_test_subjects()
        self.create_test_cases()

    def tensor(self, data, **kwargs) -> torch.Tensor:
        """
        Shortcut to create a tensor with device=self.device.
        """
        return torch.tensor(data, **kwargs, device=self.device)

    def create_test_subjects(self) -> typing.List[Subject]:
        raise NotImplementedError()

    def create_test_cases(self):
        # one enabled-only, one disabled-only, and one mixed batch
        self.test_cases = [
            Case(enabled=True, test_subjects=[self.enabled]),
            Case(enabled=False, test_subjects=[self.disabled]),
            Case(enabled=True, test_subjects=[self.enabled, self.disabled]),
        ]

    def _create_penalizer(
        self, case: Case
    ) -> typing.Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
        orchestrator = BatchedPenalizerOrchestrator(
            vocab_size=self.vocab_size,
            batch=_BatchLike(reqs=[subject.to_req() for subject in case.test_subjects]),
            device=self.device,
            Penalizers={self.Penalizer},
        )

        return orchestrator, orchestrator.penalizers[self.Penalizer]

    def test_is_required(self):
        """is_required() should mirror whether the case enables the penalizer."""
        for case in self.test_cases:
            with self.subTest(case=case):
                _, penalizer = self._create_penalizer(case)
                self.assertEqual(case.enabled, penalizer.is_required())

    def test_prepare(self):
        """Preparation should materialize the expected tensors, and applying the
        penalizer to all-ones logits should yield the first step's expected logits."""
        for case in self.test_cases:
            with self.subTest(case=case):
                orchestrator, penalizer = self._create_penalizer(case)
                self.assertEqual(case.enabled, penalizer.is_prepared())

                if case.enabled:
                    for key, tensor in {
                        key: torch.cat(
                            tensors=[
                                subject.steps[0].expected_tensors[key]
                                for subject in case.test_subjects
                            ],
                        )
                        for key in case.tensor_keys()
                    }.items():
                        torch.testing.assert_close(
                            actual=getattr(penalizer, key),
                            expected=tensor,
                            msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
                        )

                actual = orchestrator.apply(
                    torch.ones(
                        size=(len(case.test_subjects), self.vocab_size),
                        dtype=torch.float32,
                        device=self.device,
                    )
                )
                expected = torch.cat(
                    tensors=[
                        subject.steps[0].expected_logits
                        for subject in case.test_subjects
                    ],
                )
                torch.testing.assert_close(
                    actual=actual,
                    expected=expected,
                    msg=f"logits\nactual={actual}\nexpected={expected}",
                )

    def test_teardown(self):
        """teardown() should release every tensor the penalizer prepared."""
        for case in self.test_cases:
            with self.subTest(case=case):
                _, penalizer = self._create_penalizer(case)
                penalizer.teardown()

                for key in case.test_subjects[0].steps[0].expected_tensors.keys():
                    self.assertIsNone(getattr(penalizer, key, None))

    def test_filter(self):
        """Filtering the batch should shrink the penalizer state to the kept requests."""
        for case in self.test_cases:
            with self.subTest(case=case):
                orchestrator, penalizer = self._create_penalizer(case)

                indices_to_keep = [0]
                orchestrator.filter(indices_to_keep=indices_to_keep)

                filtered_subjects = [case.test_subjects[i] for i in indices_to_keep]

                if penalizer.is_required():
                    self.assertTrue(penalizer.is_prepared())
                    for key, tensor in {
                        key: torch.cat(
                            tensors=[
                                subject.steps[0].expected_tensors[key]
                                for subject in filtered_subjects
                            ],
                        )
                        for key in case.tensor_keys()
                    }.items():
                        torch.testing.assert_close(
                            actual=getattr(penalizer, key),
                            expected=tensor,
                            msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
                        )

                actual_logits = orchestrator.apply(
                    torch.ones(
                        size=(len(filtered_subjects), self.vocab_size),
                        dtype=torch.float32,
                        device=self.device,
                    )
                )
                filtered_expected_logits = torch.cat(
                    tensors=[
                        subject.steps[0].expected_logits
                        for subject in filtered_subjects
                    ],
                )
                torch.testing.assert_close(
                    actual=actual_logits,
                    expected=filtered_expected_logits,
                    msg=f"logits\nactual={actual_logits}\nexpected={filtered_expected_logits}",
                )

    def test_merge_enabled_with_disabled(self):
        """Merging a disabled orchestrator into an enabled one should concatenate
        the per-request tensors of both batches."""
        enabled_test_case = self.test_cases[0]
        disabled_test_case = self.test_cases[1]

        orchestrator, penalizer = self._create_penalizer(enabled_test_case)
        theirs, _ = self._create_penalizer(disabled_test_case)

        orchestrator.merge(theirs)

        for key, tensor in {
            key: torch.cat(
                tensors=[
                    enabled_test_case.test_subjects[0].steps[0].expected_tensors[key],
                    disabled_test_case.test_subjects[0].steps[0].expected_tensors[key],
                ],
            )
            for key in enabled_test_case.tensor_keys()
        }.items():
            torch.testing.assert_close(
                actual=getattr(penalizer, key),
                expected=tensor,
                msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
            )

    def test_cumulate_apply_repeat(self):
        """Replay each subject step by step: filter out finished subjects, cumulate
        the step's input/output tokens, then check tensors and applied logits."""
        for case in self.test_cases:
            with self.subTest(case=case):
                orchestrator, penalizer = self._create_penalizer(case)

                max_step = max(len(subject.steps) for subject in case.test_subjects)
                for i in range(1, max_step):
                    orchestrator.filter(
                        indices_to_keep=[
                            j
                            for j, subject in enumerate(case.test_subjects)
                            if i < len(subject.steps)
                        ]
                    )

                    filtered_subjects = [
                        subject
                        for subject in case.test_subjects
                        if i < len(subject.steps)
                    ]

                    inputs: typing.List[typing.List[int]] = []
                    outputs: typing.List[typing.List[int]] = []
                    for subject in filtered_subjects:
                        step = subject.steps[i]
                        if step.type == StepType.INPUT:
                            inputs.append(step.token_ids)
                            outputs.append([])
                        else:
                            inputs.append([])
                            outputs.append(step.token_ids)

                    if any(inputs):
                        orchestrator.cumulate_input_tokens(inputs)

                    if any(outputs):
                        orchestrator.cumulate_output_tokens(outputs)

                    if penalizer.is_required():
                        self.assertTrue(penalizer.is_prepared())
                        for key, tensor in {
                            key: torch.cat(
                                tensors=[
                                    subject.steps[i].expected_tensors[key]
                                    for subject in filtered_subjects
                                ],
                            )
                            for key in case.tensor_keys()
                        }.items():
                            torch.testing.assert_close(
                                actual=getattr(penalizer, key),
                                expected=tensor,
                                msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
                            )

                    actual_logits = orchestrator.apply(
                        torch.ones(
                            size=(len(filtered_subjects), self.vocab_size),
                            dtype=torch.float32,
                            device=self.device,
                        )
                    )
                    filtered_expected_logits = torch.cat(
                        tensors=[
                            subject.steps[i].expected_logits
                            for subject in filtered_subjects
                        ],
                    )
                    torch.testing.assert_close(
                        actual=actual_logits,
                        expected=filtered_expected_logits,
                        msg=f"logits\nactual={actual_logits}\nexpected={filtered_expected_logits}",
                    )
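
# Illustrative sketch (hypothetical, not part of this harness): how a concrete
# test module is expected to plug into BaseBatchedPenalizerTest. The name
# "BatchedExamplePenalizer" does not exist in sglang and the subject-building
# helpers below are placeholders; a real subclass assigns the penalizer class
# under test to `Penalizer`, builds `self.enabled` / `self.disabled` Subjects
# (as in the sketch above the base class), and inherits every test_* method.
#
#     class TestBatchedExamplePenalizer(BaseBatchedPenalizerTest):
#         Penalizer = BatchedExamplePenalizer
#
#         def create_test_subjects(self):
#             # sampling params that activate the penalizer
#             self.enabled = build_enabled_subject(self)
#             # default sampling params, so is_required() returns False
#             self.disabled = build_disabled_subject(self)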