diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 20007d1dc..ce5fcc8d0 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1019,7 +1019,7 @@ class ScheduleBatch: extend_prefix_lens = self.prefix_lens extend_logprob_start_lens = self.extend_logprob_start_lens - if self.sampling_info is not None: + if self.sampling_info: if self.has_grammar: self.sampling_info.grammars = [req.grammar for req in self.reqs] else: @@ -1063,6 +1063,7 @@ class ScheduleBatch: out_cache_loc=self.out_cache_loc, return_logprob=self.return_logprob, decoding_reqs=self.decoding_reqs, + sampling_info=dataclasses.replace(self.sampling_info), ) def __str__(self): @@ -1122,20 +1123,6 @@ class ModelWorkerBatch: # Sampling info sampling_info: SamplingBatchInfo - def copy(self): - return dataclasses.replace(self, sampling_info=self.sampling_info.copy()) - - def to(self, device: str): - self.input_ids = self.input_ids.to(device, non_blocking=True) - self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True) - self.seq_lens = self.seq_lens.to(device, non_blocking=True) - self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True) - self.req_to_token_pool_records = [ - (x, y.to(device, non_blocking=True)) - for x, y in self.req_to_token_pool_records - ] - self.sampling_info.to(device) - @triton.jit def write_req_to_token_pool_triton( diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 46f431adf..7012ddf63 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -931,14 +931,14 @@ class Scheduler: # Check finish conditions logprob_pt = 0 - for i, req in enumerate(batch.reqs): + for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)): if req.is_retracted: continue if req.is_being_chunked <= 0: # Inflight reqs' prefill is not finished req.completion_tokens_wo_jump_forward += 1 - req.output_ids.append(next_token_ids[i]) + req.output_ids.append(next_token_id) req.check_finished() if req.finished(): @@ -947,7 +947,7 @@ class Scheduler: self.tree_cache.cache_unfinished_req(req) if req.grammar is not None: - req.grammar.accept_token(next_token_ids[i]) + req.grammar.accept_token(next_token_id) if req.return_logprob: logprob_pt += self.add_logprob_return_values( diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 4900575ee..6e5bce36a 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -16,6 +16,7 @@ limitations under the License. """A tensor parallel worker.""" import logging +import threading from typing import Optional from sglang.srt.configs.model_config import ModelConfig @@ -138,9 +139,15 @@ class TpModelWorker: forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) self.model_runner.forward(forward_batch) - def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): + def forward_batch_generation( + self, + model_worker_batch: ModelWorkerBatch, + launch_event: Optional[threading.Event] = None, + ): forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) logits_output = self.model_runner.forward(forward_batch) + if launch_event: + launch_event.set() next_token_ids = self.model_runner.sample(logits_output, model_worker_batch) return logits_output, next_token_ids diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index f5e43f348..6b42d3974 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -15,6 +15,7 @@ limitations under the License. """A tensor parallel worker.""" +import dataclasses import logging import threading import time @@ -107,7 +108,7 @@ class TpModelWorkerClient: # Run forward logits_output, next_token_ids = self.worker.forward_batch_generation( - model_worker_batch + model_worker_batch, self.launch_event ) # Update the future token ids map @@ -134,7 +135,6 @@ class TpModelWorkerClient: next_token_ids = next_token_ids.to("cpu", non_blocking=True) copy_event.record() - self.launch_event.set() self.output_queue.put((copy_event, logits_output, next_token_ids)) def resolve_batch_result(self, bid: int): @@ -159,7 +159,10 @@ class TpModelWorkerClient: def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): # Push a new batch to the queue - self.input_queue.put((model_worker_batch.copy(), self.future_token_ids_ct)) + model_worker_batch.sampling_info = dataclasses.replace( + model_worker_batch.sampling_info + ) + self.input_queue.put((model_worker_batch, self.future_token_ids_ct)) # Allocate output future objects bs = len(model_worker_batch.seq_lens) diff --git a/python/sglang/srt/sampling/penaltylib/orchestrator.py b/python/sglang/srt/sampling/penaltylib/orchestrator.py index c35e8edba..9c393d180 100644 --- a/python/sglang/srt/sampling/penaltylib/orchestrator.py +++ b/python/sglang/srt/sampling/penaltylib/orchestrator.py @@ -1,40 +1,34 @@ import abc import dataclasses -import typing +from typing import List, Set, Type, Union import torch @dataclasses.dataclass class _ReqLike: - origin_input_ids: typing.Union[torch.Tensor, typing.List[int]] + origin_input_ids: List[int] @dataclasses.dataclass class _BatchLike: - reqs: typing.List[_ReqLike] + reqs: List[_ReqLike] def batch_size(self): return len(self.reqs) class BatchedPenalizerOrchestrator: - batch: _BatchLike - device: str - vocab_size: int - penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"] - def __init__( self, vocab_size: int, batch: _BatchLike, device: str, - Penalizers: typing.Set[typing.Type["_BatchedPenalizer"]], + Penalizers: Set[Type["_BatchedPenalizer"]], ): self.vocab_size = vocab_size self.batch = batch self.device = device - self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers} is_required = False @@ -43,10 +37,12 @@ class BatchedPenalizerOrchestrator: is_required |= pen_is_required self.is_required = is_required + input_ids = [ + torch.tensor(req.origin_input_ids, dtype=torch.int64, device=self.device) + for req in self.reqs() + ] if self.is_required: - self.cumulate_input_tokens( - input_ids=[req.origin_input_ids for req in self.reqs()] - ) + self.cumulate_input_tokens(input_ids=input_ids) def reqs(self): return self.batch.reqs @@ -54,34 +50,24 @@ class BatchedPenalizerOrchestrator: def batch_size(self): return self.batch.batch_size() - def cumulate_input_tokens( - self, - input_ids: typing.Union[ - typing.List[torch.Tensor], typing.List[typing.List[int]] - ], - ): + def cumulate_input_tokens(self, input_ids: List[torch.Tensor]): """ Feed the input tokens to the penalizers. Args: - input_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The input tokens. + input_ids (List[torch.Tensor]): The input tokens. """ token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids) for penalizer in self.penalizers.values(): penalizer.cumulate_input_tokens(input_ids=token_ids) - def cumulate_output_tokens( - self, - output_ids: typing.Union[ - typing.List[torch.Tensor], typing.List[typing.List[int]] - ], - ): + def cumulate_output_tokens(self, output_ids: torch.Tensor): """ Feed the output tokens to the penalizers. Args: - output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens. + output_ids (torch.Tensor): The output tokens. """ if not self.is_required: return @@ -112,14 +98,14 @@ class BatchedPenalizerOrchestrator: def filter( self, - indices_to_keep: typing.List[int], + indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor = None, ): """ Filter the penalizers based on the indices to keep in the batch. Args: - indices_to_keep (typing.List[int]): List of indices to keep in the batch. + indices_to_keep (List[int]): List of indices to keep in the batch. indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor. """ if not self.is_required: @@ -174,32 +160,18 @@ class _TokenIDs: Attributes: orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to. - token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs. + token_ids (Union[torch.Tensor, List[torch.Tensor]]): The token IDs. cached_counts (torch.Tensor): The cached occurrence count tensor. """ - orchestrator: BatchedPenalizerOrchestrator - token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]] - cached_counts: torch.Tensor = None - def __init__( self, orchestrator: BatchedPenalizerOrchestrator, - token_ids: typing.Union[ - typing.List[torch.Tensor], typing.List[typing.List[int]] - ], + token_ids: Union[torch.Tensor, List[torch.Tensor]], ): self.orchestrator = orchestrator - - if not isinstance(token_ids[0], torch.Tensor): - token_ids = [ - torch.tensor( - data=ids, dtype=torch.int64, device=self.orchestrator.device - ) - for ids in token_ids - ] - self.token_ids = token_ids + self.cached_counts = None def occurrence_count(self) -> torch.Tensor: """ @@ -213,30 +185,34 @@ class _TokenIDs: token_ids = self.token_ids - if isinstance(token_ids, torch.Tensor): - token_ids = token_ids.unsqueeze(1) - - # needs to be long to be used as index in scatter_add - if token_ids.dtype != torch.int64: - token_ids = token_ids.to(torch.int64) - - padded_token_ids = torch.nn.utils.rnn.pad_sequence( - sequences=token_ids, - batch_first=True, - padding_value=self.orchestrator.vocab_size, - ) - - self.cached_counts = torch.zeros( - size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1), - dtype=torch.int64, - device=self.orchestrator.device, - ).scatter_add_( - dim=1, - index=padded_token_ids, - src=torch.ones_like(padded_token_ids), - )[ - :, : self.orchestrator.vocab_size - ] + if isinstance(token_ids, list): + # TODO: optimize this part + padded_token_ids = torch.nn.utils.rnn.pad_sequence( + sequences=token_ids, + batch_first=True, + padding_value=self.orchestrator.vocab_size, + ) + self.cached_counts = torch.zeros( + size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1), + dtype=torch.int64, + device=self.orchestrator.device, + ).scatter_add_( + dim=1, + index=padded_token_ids, + src=torch.ones_like(padded_token_ids), + )[ + :, : self.orchestrator.vocab_size + ] + else: + # TODO: optimize this part. We do not need to create this big tensor every time. + # We can directly apply the results on the logits. + self.cached_counts = torch.zeros( + size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size), + device=self.orchestrator.device, + ) + self.cached_counts[ + torch.arange(len(token_ids), device=self.orchestrator.device), token_ids + ] = 1 return self.cached_counts @@ -246,11 +222,9 @@ class _BatchedPenalizer(abc.ABC): An abstract class for a batched penalizer. """ - orchestrator: BatchedPenalizerOrchestrator - _is_prepared: bool = False - def __init__(self, orchestrator: BatchedPenalizerOrchestrator): self.orchestrator = orchestrator + self._is_prepared = False def is_prepared(self) -> bool: return self._is_prepared @@ -293,9 +267,7 @@ class _BatchedPenalizer(abc.ABC): return self._apply(logits=logits) - def filter( - self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor - ): + def filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor): if not self.is_prepared(): return @@ -360,9 +332,7 @@ class _BatchedPenalizer(abc.ABC): pass @abc.abstractmethod - def _filter( - self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor - ): + def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor): """ Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch. """ diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py b/python/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py index 178cb54b2..34fa5abbf 100644 --- a/python/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +++ b/python/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py @@ -1,8 +1,8 @@ -import typing +from typing import List import torch -from ..orchestrator import _BatchedPenalizer, _TokenIDs +from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs class BatchedFrequencyPenalizer(_BatchedPenalizer): @@ -44,9 +44,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer): ) def _teardown(self): - del self.frequency_penalties - del self.cumulated_frequency_penalties - self.frequency_penalties = None self.cumulated_frequency_penalties = None @@ -62,9 +59,7 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer): logits -= self.cumulated_frequency_penalties return logits - def _filter( - self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor - ): + def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor): self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep] self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[ indices_tensor_to_keep diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py b/python/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py index cc97a2eac..0e27c7e5a 100644 --- a/python/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +++ b/python/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py @@ -1,8 +1,8 @@ -import typing +from typing import List import torch -from ..orchestrator import _BatchedPenalizer, _TokenIDs +from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs class BatchedMinNewTokensPenalizer(_BatchedPenalizer): @@ -70,10 +70,6 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer): ) def _teardown(self): - del self.min_new_tokens - del self.stop_token_penalties - del self.len_output_tokens - self.min_new_tokens = None self.stop_token_penalties = None self.len_output_tokens = None @@ -89,9 +85,7 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer): logits[mask] += self.stop_token_penalties[mask] return logits - def _filter( - self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor - ): + def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor): self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep] self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep] self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep] diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py b/python/sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py index 0593fddc9..f86aa4a2d 100644 --- a/python/sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +++ b/python/sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py @@ -1,8 +1,8 @@ -import typing +from typing import List import torch -from ..orchestrator import _BatchedPenalizer, _TokenIDs +from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs class BatchedPresencePenalizer(_BatchedPenalizer): @@ -44,9 +44,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer): ) def _teardown(self): - del self.presence_penalties - del self.cumulated_presence_penalties - self.presence_penalties = None self.cumulated_presence_penalties = None @@ -61,9 +58,7 @@ class BatchedPresencePenalizer(_BatchedPenalizer): logits -= self.cumulated_presence_penalties return logits - def _filter( - self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor - ): + def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor): self.presence_penalties = self.presence_penalties[indices_tensor_to_keep] self.cumulated_presence_penalties = self.cumulated_presence_penalties[ indices_tensor_to_keep diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py b/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py index ea32addc2..4c293b895 100644 --- a/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +++ b/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py @@ -1,8 +1,8 @@ -import typing +from typing import List import torch -from ..orchestrator import _BatchedPenalizer, _TokenIDs +from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs class BatchedRepetitionPenalizer(_BatchedPenalizer): @@ -44,9 +44,6 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer): ) def _teardown(self): - del self.repetition_penalties - del self.cumulated_repetition_penalties - self.repetition_penalties = None self.cumulated_repetition_penalties = None @@ -65,9 +62,7 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer): logits * self.cumulated_repetition_penalties, ) - def _filter( - self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor - ): + def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor): self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep] self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[ indices_tensor_to_keep diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 61aa341fd..41b88e966 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -27,10 +27,10 @@ class SamplingBatchInfo: # Bias Tensors vocab_size: int + grammars: Optional[List] = None logit_bias: torch.Tensor = None vocab_mask: Optional[torch.Tensor] = None apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None - grammars: Optional[List] = None # Penalizer penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None @@ -211,25 +211,3 @@ class SamplingBatchInfo: self.logit_bias = SamplingBatchInfo.merge_bias_tensor( self.logit_bias, other.logit_bias, len(self), len(other), self.device ) - - def copy(self): - return SamplingBatchInfo( - temperatures=self.temperatures, - top_ps=self.top_ps, - top_ks=self.top_ks, - min_ps=self.min_ps, - is_all_greedy=self.is_all_greedy, - need_min_p_sampling=self.need_min_p_sampling, - vocab_size=self.vocab_size, - device=self.device, - ) - - def to(self, device: str): - for item in [ - "temperatures", - "top_ps", - "top_ks", - "min_ps", - ]: - value = getattr(self, item) - setattr(self, item, value.to(device, non_blocking=True)) diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py index e5b876f6d..df5439979 100644 --- a/python/sglang/srt/sampling/sampling_params.py +++ b/python/sglang/srt/sampling/sampling_params.py @@ -24,7 +24,6 @@ class SamplingParams: def __init__( self, max_new_tokens: int = 128, - min_new_tokens: int = 0, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, temperature: float = 1.0, @@ -34,6 +33,7 @@ class SamplingParams: frequency_penalty: float = 0.0, presence_penalty: float = 0.0, repetition_penalty: float = 1.0, + min_new_tokens: int = 0, spaces_between_special_tokens: bool = True, regex: Optional[str] = None, n: int = 1, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 394bb519b..204e98da1 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -782,7 +782,7 @@ class PortArgs: @staticmethod def init_new(server_args) -> "PortArgs": - port = server_args.port + 42 + port = server_args.port + random.randint(100, 1000) while True: if is_port_available(port): break diff --git a/python/sglang/test/srt/sampling/penaltylib/utils.py b/python/sglang/test/srt/sampling/penaltylib/utils.py index 4acfa5a41..431efa9b3 100644 --- a/python/sglang/test/srt/sampling/penaltylib/utils.py +++ b/python/sglang/test/srt/sampling/penaltylib/utils.py @@ -1,7 +1,7 @@ import dataclasses import enum -import typing import unittest +from typing import Dict, List, Optional, Set, Tuple, Type import torch @@ -16,7 +16,7 @@ from sglang.srt.sampling.penaltylib.orchestrator import ( class MockSamplingParams: frequency_penalty: float = 0.0 min_new_tokens: int = 0 - stop_token_ids: typing.List[int] = None + stop_token_ids: List[int] = None presence_penalty: float = 0.0 repetition_penalty: float = 1.0 @@ -24,12 +24,12 @@ class MockSamplingParams: @dataclasses.dataclass class MockTokenizer: eos_token_id: int - additional_stop_token_ids: typing.Optional[typing.List[int]] = None + additional_stop_token_ids: Optional[List[int]] = None @dataclasses.dataclass class MockReq: - origin_input_ids: typing.List[int] + origin_input_ids: List[int] sampling_params: MockSamplingParams tokenizer: MockTokenizer @@ -42,8 +42,8 @@ class StepType(enum.Enum): @dataclasses.dataclass class Step: type: StepType - token_ids: typing.List[int] - expected_tensors: typing.Dict[str, torch.Tensor] + token_ids: List[int] + expected_tensors: Dict[str, torch.Tensor] # assume initial logits are all 1 expected_logits: torch.Tensor @@ -52,7 +52,7 @@ class Step: class Subject: sampling_params: MockSamplingParams # first step must be input, which will be converted to Req - steps: typing.List[Step] + steps: List[Step] eos_token_id: int = -1 def __post_init__(self): @@ -66,7 +66,7 @@ class Subject: f"Expected tensors keys must be the same for all steps. Got {self.steps[i].expected_tensors.keys()} for key={i} and {self.steps[0].expected_tensors.keys()}" ) - def tensor_keys(self, i: int = 0) -> typing.Set[str]: + def tensor_keys(self, i: int = 0) -> Set[str]: return set(self.steps[i].expected_tensors.keys()) def to_req(self) -> MockReq: @@ -80,7 +80,7 @@ class Subject: @dataclasses.dataclass class Case: enabled: bool - test_subjects: typing.List[Subject] + test_subjects: List[Subject] def __post_init__(self): # each test_subjects.steps should have the same expected_tensors.keys() @@ -90,12 +90,12 @@ class Case: f"Expected tensors keys must be the same for all test_subjects. Got {self.test_subjects[i].tensor_keys()} for key={i} and {self.test_subjects[0].tensor_keys()}" ) - def tensor_keys(self, i: int = 0) -> typing.List[str]: + def tensor_keys(self, i: int = 0) -> List[str]: return set(self.test_subjects[i].tensor_keys()) class BaseBatchedPenalizerTest(unittest.TestCase): - Penalizer: typing.Type[_BatchedPenalizer] + Penalizer: Type[_BatchedPenalizer] device = "cuda" vocab_size = 5 @@ -115,7 +115,7 @@ class BaseBatchedPenalizerTest(unittest.TestCase): """ return torch.tensor(data, **kwargs, device=self.device) - def create_test_subjects(self) -> typing.List[Subject]: + def create_test_subjects(self) -> List[Subject]: raise NotImplementedError() def create_test_cases(self): @@ -127,7 +127,7 @@ class BaseBatchedPenalizerTest(unittest.TestCase): def _create_penalizer( self, case: Case - ) -> typing.Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]: + ) -> Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]: orchestrator = BatchedPenalizerOrchestrator( vocab_size=self.vocab_size, batch=_BatchLike(reqs=[subject.to_req() for subject in case.test_subjects]), @@ -287,22 +287,24 @@ class BaseBatchedPenalizerTest(unittest.TestCase): if i < len(subject.steps) ] - inputs: typing.List[typing.List[int]] = [] - outputs: typing.List[typing.List[int]] = [] + inputs: List[List[int]] = [] + outputs: List[List[int]] = [] for subject in filtered_subjects: step = subject.steps[i] if step.type == StepType.INPUT: - inputs.append(step.token_ids) - outputs.append([]) + raise NotImplementedError() else: inputs.append([]) outputs.append(step.token_ids) - if any(inputs): - orchestrator.cumulate_input_tokens(inputs) - if any(outputs): - orchestrator.cumulate_output_tokens(outputs) + for j in range(max(len(x) for x in outputs)): + tmp_outputs = torch.tensor( + [x[j] for x in outputs], + dtype=torch.int32, + device=orchestrator.device, + ) + orchestrator.cumulate_output_tokens(tmp_outputs) if penalizer.is_required(): self.assertTrue(penalizer.is_prepared()) diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py index 106196a6a..b99606fc1 100644 --- a/test/lang/test_srt_backend.py +++ b/test/lang/test_srt_backend.py @@ -1,3 +1,8 @@ +""" +Usage: +python3 -m unittest test_srt_backend.TestSRTBackend.test_gen_min_new_tokens +""" + import unittest import sglang as sgl @@ -68,7 +73,7 @@ class TestSRTBackend(unittest.TestCase): # Run twice to capture more bugs for _ in range(2): accuracy, latency = test_hellaswag_select() - assert accuracy > 0.71, f"{accuracy=}" + self.assertGreater(accuracy, 0.71) def test_gen_min_new_tokens(self): test_gen_min_new_tokens() diff --git a/test/srt/sampling/penaltylib/penalizers/test_frequency_penalty.py b/test/srt/sampling/penaltylib/penalizers/test_frequency_penalty.py index 59db353ab..e8a8fe033 100644 --- a/test/srt/sampling/penaltylib/penalizers/test_frequency_penalty.py +++ b/test/srt/sampling/penaltylib/penalizers/test_frequency_penalty.py @@ -1,5 +1,5 @@ -import typing import unittest +from typing import List import torch @@ -48,7 +48,11 @@ class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest): ), Step( type=StepType.OUTPUT, - token_ids=[1, 2, 2], + token_ids=[ + 1, + 2, + 2, + ], # This is the output ids of one request in three steps. expected_tensors={ "frequency_penalties": self.tensor( [[frequency_penalty] * self.vocab_size], dtype=torch.float32 @@ -76,7 +80,7 @@ class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest): ], ) - def create_test_subjects(self) -> typing.List[Subject]: + def create_test_subjects(self) -> List[Subject]: self.enabled = self._create_subject(frequency_penalty=self.frequency_penalty) self.disabled = self._create_subject(frequency_penalty=0.0) diff --git a/test/srt/sampling/penaltylib/penalizers/test_min_new_tokens.py b/test/srt/sampling/penaltylib/penalizers/test_min_new_tokens.py index 1984aafe5..298dd2cc1 100644 --- a/test/srt/sampling/penaltylib/penalizers/test_min_new_tokens.py +++ b/test/srt/sampling/penaltylib/penalizers/test_min_new_tokens.py @@ -1,5 +1,5 @@ -import typing import unittest +from typing import List import torch @@ -143,7 +143,7 @@ class TestBatchedMinNewTokensPenalizer(BaseBatchedPenalizerTest): ], ) - def create_test_subjects(self) -> typing.List[Subject]: + def create_test_subjects(self) -> List[Subject]: self.enabled = self._create_subject(min_new_tokens=MIN_NEW_TOKENS) self.disabled = self._create_subject(min_new_tokens=0.0) diff --git a/test/srt/sampling/penaltylib/penalizers/test_presence_penalty.py b/test/srt/sampling/penaltylib/penalizers/test_presence_penalty.py index 96cbf1082..b249283ac 100644 --- a/test/srt/sampling/penaltylib/penalizers/test_presence_penalty.py +++ b/test/srt/sampling/penaltylib/penalizers/test_presence_penalty.py @@ -1,5 +1,5 @@ -import typing import unittest +from typing import List import torch @@ -76,7 +76,7 @@ class BaseBatchedPresencePenalizerTest(BaseBatchedPenalizerTest): ], ) - def create_test_subjects(self) -> typing.List[Subject]: + def create_test_subjects(self) -> List[Subject]: self.enabled = self._create_subject(presence_penalty=self.presence_penalty) self.disabled = self._create_subject(presence_penalty=0.0) diff --git a/test/srt/sampling/penaltylib/penalizers/test_repetition_penalty.py b/test/srt/sampling/penaltylib/penalizers/test_repetition_penalty.py index e3751c14a..2f8671391 100644 --- a/test/srt/sampling/penaltylib/penalizers/test_repetition_penalty.py +++ b/test/srt/sampling/penaltylib/penalizers/test_repetition_penalty.py @@ -1,5 +1,5 @@ -import typing import unittest +from typing import List import torch @@ -78,7 +78,7 @@ class TestBatchedRepetitionPenalizer(BaseBatchedPenalizerTest): ], ) - def create_test_subjects(self) -> typing.List[Subject]: + def create_test_subjects(self) -> List[Subject]: self.enabled = self._create_subject(repetition_penalty=REPETITION_PENALTY) self.disabled = self._create_subject(repetition_penalty=1.0)