diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 6cd5127bd..000f8ecdc 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -515,11 +515,11 @@ class ScheduleBatch: assert seq_len - pre_len == req.extend_input_len if pre_len > 0: - self.req_to_token_pool.req_to_token[req.req_pool_idx][ - :pre_len - ] = req.prefix_indices + self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = ( + req.prefix_indices + ) - self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = ( + self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = ( out_cache_loc[pt : pt + req.extend_input_len] ) @@ -535,10 +535,15 @@ class ScheduleBatch: pt += req.extend_input_len # Set fields - with out_cache_loc.device: - self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32) - self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32) - self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32) + self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to( + self.device, non_blocking=True + ) + self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to( + self.device, non_blocking=True + ) + self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to( + self.device, non_blocking=True + ) self.extend_num_tokens = extend_num_tokens self.out_cache_loc = out_cache_loc @@ -782,8 +787,8 @@ class ScheduleBatch: return self.reqs = [self.reqs[i] for i in keep_indices] - new_indices = torch.tensor( - keep_indices, dtype=torch.int32, device=self.seq_lens.device + new_indices = torch.tensor(keep_indices, dtype=torch.int32).to( + self.device, non_blocking=True ) self.req_pool_indices = self.req_pool_indices[new_indices] self.seq_lens = self.seq_lens[new_indices] diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e0588c407..16c43dd16 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -150,6 +150,7 @@ class Scheduler: nccl_port=port_args.nccl_port, ) self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group + self.device = self.tp_worker.device # Get token and memory info from the model worker ( @@ -758,9 +759,7 @@ class Scheduler: if logits_output.next_token_logprobs is not None: logits_output.next_token_logprobs = ( logits_output.next_token_logprobs[ - torch.arange( - len(next_token_ids), device=next_token_ids.device - ), + torch.arange(len(next_token_ids), device=self.device), next_token_ids, ].tolist() ) @@ -828,7 +827,7 @@ class Scheduler: # Move logprobs to cpu if batch.return_logprob: next_token_logprobs = logits_output.next_token_logprobs[ - torch.arange(len(next_token_ids), device=next_token_ids.device), + torch.arange(len(next_token_ids), device=self.device), next_token_ids, ].tolist() diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index c8afc1572..f5ae3b00a 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -90,7 +90,7 @@ class BaseTokenToKVPool: select_index = self.free_slots[:need_size] self.free_slots = self.free_slots[need_size:] - return select_index.to(self.device) + return select_index.to(self.device, non_blocking=True) def free(self, free_index: torch.Tensor): if self.is_not_in_free_group: diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 555e3db95..d18044bff 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -135,25 +135,22 @@ class ForwardBatch: # Init position information if not ret.forward_mode.is_decode(): - ret.positions = torch.tensor( - np.concatenate( - [ - np.arange(prefix_len, prefix_len + extend_len) - for prefix_len, extend_len in zip( - batch.extend_prefix_lens, batch.extend_seq_lens - ) - ], - axis=0, - ), - dtype=torch.int64, - device=device, + ret.positions = torch.concat( + [ + torch.arange(prefix_len, prefix_len + extend_len, device=device) + for prefix_len, extend_len in zip( + batch.extend_prefix_lens, batch.extend_seq_lens + ) + ], + axis=0, ) - ret.image_inputs = batch.image_inputs - ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device) + ret.extend_seq_lens = torch.tensor( + batch.extend_seq_lens, dtype=torch.int32 + ).to(device, non_blocking=True) ret.extend_prefix_lens = torch.tensor( - batch.extend_prefix_lens, device=device - ) + batch.extend_prefix_lens, dtype=torch.int32 + ).to(device, non_blocking=True) ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens) ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0) ret.extend_seq_lens_cpu = batch.extend_seq_lens diff --git a/python/sglang/srt/sampling/penaltylib/orchestrator.py b/python/sglang/srt/sampling/penaltylib/orchestrator.py index 4214a746b..c35e8edba 100644 --- a/python/sglang/srt/sampling/penaltylib/orchestrator.py +++ b/python/sglang/srt/sampling/penaltylib/orchestrator.py @@ -37,12 +37,16 @@ class BatchedPenalizerOrchestrator: self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers} + is_required = False for penalizer in self.penalizers.values(): - penalizer.prepare_if_required() + pen_is_required = penalizer.prepare_if_required() + is_required |= pen_is_required + self.is_required = is_required - self.cumulate_input_tokens( - input_ids=[req.origin_input_ids for req in self.reqs()] - ) + if self.is_required: + self.cumulate_input_tokens( + input_ids=[req.origin_input_ids for req in self.reqs()] + ) def reqs(self): return self.batch.reqs @@ -79,6 +83,9 @@ class BatchedPenalizerOrchestrator: Args: output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens. """ + if not self.is_required: + return + token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids) for penalizer in self.penalizers.values(): @@ -95,6 +102,9 @@ class BatchedPenalizerOrchestrator: Returns: torch.Tensor: The logits after applying the penalizers. """ + if not self.is_required: + return + for penalizer in self.penalizers.values(): logits = penalizer.apply(logits) @@ -112,10 +122,16 @@ class BatchedPenalizerOrchestrator: indices_to_keep (typing.List[int]): List of indices to keep in the batch. indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor. """ + if not self.is_required: + return + empty_indices = len(indices_to_keep) == 0 + is_required = False for penalizer in self.penalizers.values(): - if not penalizer.is_required() or empty_indices: + tmp_is_required = penalizer.is_required() + is_required = is_required or tmp_is_required + if not tmp_is_required or empty_indices: penalizer.teardown() else: # create tensor index only when it's needed @@ -128,6 +144,7 @@ class BatchedPenalizerOrchestrator: indices_to_keep=indices_to_keep, indices_tensor_to_keep=indices_tensor_to_keep, ) + self.is_required = is_required def merge(self, their: "BatchedPenalizerOrchestrator"): """ @@ -140,11 +157,10 @@ class BatchedPenalizerOrchestrator: Args: their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one. """ - if self.vocab_size != their.vocab_size: - raise ValueError( - f"vocab_size mismatch: {self.vocab_size} != {their.vocab_size}" - ) + if not self.is_required and not their.is_required: + return + self.is_required |= their.is_required for Penalizer, their_penalizer in their.penalizers.items(): if Penalizer not in self.penalizers: raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers") @@ -250,6 +266,9 @@ class _BatchedPenalizer(abc.ABC): def prepare_if_required(self): if self.is_required(): self.prepare() + return True + else: + return False def teardown(self): if self.is_prepared(): diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index cc4229ff5..37dedcd17 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -48,20 +48,24 @@ class SamplingBatchInfo: disable_penalizer: bool, ): reqs = batch.reqs - with batch.input_ids.device: - temperatures = torch.tensor( + device = batch.input_ids.device + temperatures = ( + torch.tensor( [r.sampling_params.temperature for r in reqs], dtype=torch.float, - ).view(-1, 1) - top_ps = torch.tensor( - [r.sampling_params.top_p for r in reqs], dtype=torch.float - ) - top_ks = torch.tensor( - [r.sampling_params.top_k for r in reqs], dtype=torch.int32 - ) - min_ps = torch.tensor( - [r.sampling_params.min_p for r in reqs], dtype=torch.float ) + .view(-1, 1) + .to(device, non_blocking=True) + ) + top_ps = torch.tensor( + [r.sampling_params.top_p for r in reqs], dtype=torch.float + ).to(device, non_blocking=True) + top_ks = torch.tensor( + [r.sampling_params.top_k for r in reqs], dtype=torch.int32 + ).to(device, non_blocking=True) + min_ps = torch.tensor( + [r.sampling_params.min_p for r in reqs], dtype=torch.float + ).to(device, non_blocking=True) ret = cls( temperatures=temperatures, @@ -80,7 +84,7 @@ class SamplingBatchInfo: # # While we choose not to even create the class instances if they are not required, this # could add additional complexity to the {ScheduleBatch} class, especially we need to - # handle {filter_batch()} and {merge()} cases as well. + # handle {filter_batch()} and {merge_batch()} cases as well. if disable_penalizer: ret.penalizer_orchestrator = None else: @@ -112,19 +116,20 @@ class SamplingBatchInfo: self.linear_penalties = None for penalizer in self.penalizer_orchestrator.penalizers.values(): + if not penalizer.is_prepared(): + continue + if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer): - if penalizer.is_prepared(): - self.scaling_penalties = penalizer.cumulated_repetition_penalties + self.scaling_penalties = penalizer.cumulated_repetition_penalties else: - if penalizer.is_prepared(): - if self.linear_penalties is None: - bs = self.penalizer_orchestrator.batch.batch_size() - self.linear_penalties = torch.zeros( - (bs, self.vocab_size), - dtype=torch.float32, - device=self.device, - ) - self.linear_penalties = penalizer.apply(self.linear_penalties) + if self.linear_penalties is None: + bs = self.penalizer_orchestrator.batch.batch_size() + self.linear_penalties = torch.zeros( + (bs, self.vocab_size), + dtype=torch.float32, + device=self.device, + ) + self.linear_penalties = penalizer.apply(self.linear_penalties) def update_regex_vocab_mask(self): has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms) diff --git a/python/sglang/test/srt/sampling/penaltylib/utils.py b/python/sglang/test/srt/sampling/penaltylib/utils.py index b41eac32b..31667c7f0 100644 --- a/python/sglang/test/srt/sampling/penaltylib/utils.py +++ b/python/sglang/test/srt/sampling/penaltylib/utils.py @@ -164,19 +164,20 @@ class BaseBatchedPenalizerTest(unittest.TestCase): msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}", ) - actual = orchestrator.apply( - torch.ones( - size=(len(case.test_subjects), self.vocab_size), - dtype=torch.float32, - device=self.device, - ) + original = torch.ones( + size=(len(case.test_subjects), self.vocab_size), + dtype=torch.float32, + device=self.device, ) + actual = orchestrator.apply(original.clone()) expected = torch.cat( tensors=[ subject.steps[0].expected_logits for subject in case.test_subjects ], ) + if actual is None: + actual = original torch.testing.assert_close( actual=actual, expected=expected, @@ -226,6 +227,8 @@ class BaseBatchedPenalizerTest(unittest.TestCase): device=self.device, ) ) + if actual_logits is None: + continue filtered_expected_logits = torch.cat( tensors=[ subject.steps[0].expected_logits @@ -317,19 +320,20 @@ class BaseBatchedPenalizerTest(unittest.TestCase): msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}", ) - actual_logits = orchestrator.apply( - torch.ones( - size=(len(filtered_subjects), self.vocab_size), - dtype=torch.float32, - device=self.device, - ) + original = torch.ones( + size=(len(filtered_subjects), self.vocab_size), + dtype=torch.float32, + device=self.device, ) + actual_logits = orchestrator.apply(original.clone()) filtered_expected_logits = torch.cat( tensors=[ subject.steps[i].expected_logits for subject in filtered_subjects ], ) + if actual_logits is None: + actual_logits = original torch.testing.assert_close( actual=actual_logits, expected=filtered_expected_logits,