Simplify logits penalizer (#2086)

This commit is contained in:
Lianmin Zheng
2024-11-18 17:48:28 -08:00
committed by GitHub
parent 3b44bbeecf
commit b110453802
18 changed files with 125 additions and 190 deletions

View File

@@ -1,7 +1,7 @@
import dataclasses
import enum
import typing
import unittest
from typing import Dict, List, Optional, Set, Tuple, Type
import torch
@@ -16,7 +16,7 @@ from sglang.srt.sampling.penaltylib.orchestrator import (
class MockSamplingParams:
frequency_penalty: float = 0.0
min_new_tokens: int = 0
stop_token_ids: typing.List[int] = None
stop_token_ids: List[int] = None
presence_penalty: float = 0.0
repetition_penalty: float = 1.0
@@ -24,12 +24,12 @@ class MockSamplingParams:
@dataclasses.dataclass
class MockTokenizer:
eos_token_id: int
additional_stop_token_ids: typing.Optional[typing.List[int]] = None
additional_stop_token_ids: Optional[List[int]] = None
@dataclasses.dataclass
class MockReq:
origin_input_ids: typing.List[int]
origin_input_ids: List[int]
sampling_params: MockSamplingParams
tokenizer: MockTokenizer
@@ -42,8 +42,8 @@ class StepType(enum.Enum):
@dataclasses.dataclass
class Step:
type: StepType
token_ids: typing.List[int]
expected_tensors: typing.Dict[str, torch.Tensor]
token_ids: List[int]
expected_tensors: Dict[str, torch.Tensor]
# assume initial logits are all 1
expected_logits: torch.Tensor
@@ -52,7 +52,7 @@ class Step:
class Subject:
sampling_params: MockSamplingParams
# first step must be input, which will be converted to Req
steps: typing.List[Step]
steps: List[Step]
eos_token_id: int = -1
def __post_init__(self):
@@ -66,7 +66,7 @@ class Subject:
f"Expected tensors keys must be the same for all steps. Got {self.steps[i].expected_tensors.keys()} for key={i} and {self.steps[0].expected_tensors.keys()}"
)
def tensor_keys(self, i: int = 0) -> typing.Set[str]:
def tensor_keys(self, i: int = 0) -> Set[str]:
return set(self.steps[i].expected_tensors.keys())
def to_req(self) -> MockReq:
@@ -80,7 +80,7 @@ class Subject:
@dataclasses.dataclass
class Case:
enabled: bool
test_subjects: typing.List[Subject]
test_subjects: List[Subject]
def __post_init__(self):
# each test_subjects.steps should have the same expected_tensors.keys()
@@ -90,12 +90,12 @@ class Case:
f"Expected tensors keys must be the same for all test_subjects. Got {self.test_subjects[i].tensor_keys()} for key={i} and {self.test_subjects[0].tensor_keys()}"
)
def tensor_keys(self, i: int = 0) -> typing.List[str]:
def tensor_keys(self, i: int = 0) -> List[str]:
return set(self.test_subjects[i].tensor_keys())
class BaseBatchedPenalizerTest(unittest.TestCase):
Penalizer: typing.Type[_BatchedPenalizer]
Penalizer: Type[_BatchedPenalizer]
device = "cuda"
vocab_size = 5
@@ -115,7 +115,7 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
"""
return torch.tensor(data, **kwargs, device=self.device)
def create_test_subjects(self) -> typing.List[Subject]:
def create_test_subjects(self) -> List[Subject]:
raise NotImplementedError()
def create_test_cases(self):
@@ -127,7 +127,7 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
def _create_penalizer(
self, case: Case
) -> typing.Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
) -> Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
orchestrator = BatchedPenalizerOrchestrator(
vocab_size=self.vocab_size,
batch=_BatchLike(reqs=[subject.to_req() for subject in case.test_subjects]),
@@ -287,22 +287,24 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
if i < len(subject.steps)
]
inputs: typing.List[typing.List[int]] = []
outputs: typing.List[typing.List[int]] = []
inputs: List[List[int]] = []
outputs: List[List[int]] = []
for subject in filtered_subjects:
step = subject.steps[i]
if step.type == StepType.INPUT:
inputs.append(step.token_ids)
outputs.append([])
raise NotImplementedError()
else:
inputs.append([])
outputs.append(step.token_ids)
if any(inputs):
orchestrator.cumulate_input_tokens(inputs)
if any(outputs):
orchestrator.cumulate_output_tokens(outputs)
for j in range(max(len(x) for x in outputs)):
tmp_outputs = torch.tensor(
[x[j] for x in outputs],
dtype=torch.int32,
device=orchestrator.device,
)
orchestrator.cumulate_output_tokens(tmp_outputs)
if penalizer.is_required():
self.assertTrue(penalizer.is_prepared())