Simplify logits penalizer (#2086)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import dataclasses
|
||||
import enum
|
||||
import typing
|
||||
import unittest
|
||||
from typing import Dict, List, Optional, Set, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
@@ -16,7 +16,7 @@ from sglang.srt.sampling.penaltylib.orchestrator import (
|
||||
class MockSamplingParams:
|
||||
frequency_penalty: float = 0.0
|
||||
min_new_tokens: int = 0
|
||||
stop_token_ids: typing.List[int] = None
|
||||
stop_token_ids: List[int] = None
|
||||
presence_penalty: float = 0.0
|
||||
repetition_penalty: float = 1.0
|
||||
|
||||
@@ -24,12 +24,12 @@ class MockSamplingParams:
|
||||
@dataclasses.dataclass
|
||||
class MockTokenizer:
|
||||
eos_token_id: int
|
||||
additional_stop_token_ids: typing.Optional[typing.List[int]] = None
|
||||
additional_stop_token_ids: Optional[List[int]] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MockReq:
|
||||
origin_input_ids: typing.List[int]
|
||||
origin_input_ids: List[int]
|
||||
sampling_params: MockSamplingParams
|
||||
tokenizer: MockTokenizer
|
||||
|
||||
@@ -42,8 +42,8 @@ class StepType(enum.Enum):
|
||||
@dataclasses.dataclass
|
||||
class Step:
|
||||
type: StepType
|
||||
token_ids: typing.List[int]
|
||||
expected_tensors: typing.Dict[str, torch.Tensor]
|
||||
token_ids: List[int]
|
||||
expected_tensors: Dict[str, torch.Tensor]
|
||||
# assume initial logits are all 1
|
||||
expected_logits: torch.Tensor
|
||||
|
||||
@@ -52,7 +52,7 @@ class Step:
|
||||
class Subject:
|
||||
sampling_params: MockSamplingParams
|
||||
# first step must be input, which will be converted to Req
|
||||
steps: typing.List[Step]
|
||||
steps: List[Step]
|
||||
eos_token_id: int = -1
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -66,7 +66,7 @@ class Subject:
|
||||
f"Expected tensors keys must be the same for all steps. Got {self.steps[i].expected_tensors.keys()} for key={i} and {self.steps[0].expected_tensors.keys()}"
|
||||
)
|
||||
|
||||
def tensor_keys(self, i: int = 0) -> typing.Set[str]:
|
||||
def tensor_keys(self, i: int = 0) -> Set[str]:
|
||||
return set(self.steps[i].expected_tensors.keys())
|
||||
|
||||
def to_req(self) -> MockReq:
|
||||
@@ -80,7 +80,7 @@ class Subject:
|
||||
@dataclasses.dataclass
|
||||
class Case:
|
||||
enabled: bool
|
||||
test_subjects: typing.List[Subject]
|
||||
test_subjects: List[Subject]
|
||||
|
||||
def __post_init__(self):
|
||||
# each test_subjects.steps should have the same expected_tensors.keys()
|
||||
@@ -90,12 +90,12 @@ class Case:
|
||||
f"Expected tensors keys must be the same for all test_subjects. Got {self.test_subjects[i].tensor_keys()} for key={i} and {self.test_subjects[0].tensor_keys()}"
|
||||
)
|
||||
|
||||
def tensor_keys(self, i: int = 0) -> typing.List[str]:
|
||||
def tensor_keys(self, i: int = 0) -> List[str]:
|
||||
return set(self.test_subjects[i].tensor_keys())
|
||||
|
||||
|
||||
class BaseBatchedPenalizerTest(unittest.TestCase):
|
||||
Penalizer: typing.Type[_BatchedPenalizer]
|
||||
Penalizer: Type[_BatchedPenalizer]
|
||||
device = "cuda"
|
||||
vocab_size = 5
|
||||
|
||||
@@ -115,7 +115,7 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
|
||||
"""
|
||||
return torch.tensor(data, **kwargs, device=self.device)
|
||||
|
||||
def create_test_subjects(self) -> typing.List[Subject]:
|
||||
def create_test_subjects(self) -> List[Subject]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def create_test_cases(self):
|
||||
@@ -127,7 +127,7 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
|
||||
|
||||
def _create_penalizer(
|
||||
self, case: Case
|
||||
) -> typing.Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
|
||||
) -> Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
|
||||
orchestrator = BatchedPenalizerOrchestrator(
|
||||
vocab_size=self.vocab_size,
|
||||
batch=_BatchLike(reqs=[subject.to_req() for subject in case.test_subjects]),
|
||||
@@ -287,22 +287,24 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
|
||||
if i < len(subject.steps)
|
||||
]
|
||||
|
||||
inputs: typing.List[typing.List[int]] = []
|
||||
outputs: typing.List[typing.List[int]] = []
|
||||
inputs: List[List[int]] = []
|
||||
outputs: List[List[int]] = []
|
||||
for subject in filtered_subjects:
|
||||
step = subject.steps[i]
|
||||
if step.type == StepType.INPUT:
|
||||
inputs.append(step.token_ids)
|
||||
outputs.append([])
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
inputs.append([])
|
||||
outputs.append(step.token_ids)
|
||||
|
||||
if any(inputs):
|
||||
orchestrator.cumulate_input_tokens(inputs)
|
||||
|
||||
if any(outputs):
|
||||
orchestrator.cumulate_output_tokens(outputs)
|
||||
for j in range(max(len(x) for x in outputs)):
|
||||
tmp_outputs = torch.tensor(
|
||||
[x[j] for x in outputs],
|
||||
dtype=torch.int32,
|
||||
device=orchestrator.device,
|
||||
)
|
||||
orchestrator.cumulate_output_tokens(tmp_outputs)
|
||||
|
||||
if penalizer.is_required():
|
||||
self.assertTrue(penalizer.is_prepared())
|
||||
|
||||
Reference in New Issue
Block a user