# Patch summary (text before the first "diff --git" header is ignored by patch tools):
#   * deepseek_v2.py: a call inside DeepseekV2MoE now also receives
#     num_token_non_padded=state.forward_batch.num_token_non_padded.
#   * two_batch_overlap.py:
#       - TboCudaGraphRunnerPlugin keeps a 2-element int32 tensor that is refilled
#         in capture_one_batch_size() and replay_prepare().
#       - TboForwardBatchPreparer.prepare() is split into prepare() and prepare_raw();
#         filter_batch() takes out_num_token_non_padded and stores it as the child
#         batch's num_token_non_padded instead of None.
#       - New helpers compute the per-child counts as
#         [min(split, n), max(0, n - split)].
# NOTE(review): line breaks and indentation were reconstructed from the hunk
# counts; if context whitespace does not match the target tree, apply with
# `git apply --ignore-whitespace`.
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 29f18f0ef..b4fc4d7a7 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -454,6 +454,7 @@ class DeepseekV2MoE(nn.Module):
                     num_expert_group=self.num_expert_group,
                     correction_bias=self.correction_bias,
                     routed_scaling_factor=self.routed_scaling_factor,
+                    num_token_non_padded=state.forward_batch.num_token_non_padded,
                     expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                         layer_id=self.layer_id,
                     ),
diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py
index 6b0241f40..b417de7ce 100644
--- a/python/sglang/srt/two_batch_overlap.py
+++ b/python/sglang/srt/two_batch_overlap.py
@@ -110,7 +110,7 @@ def compute_split_indices_for_cuda_graph_replay(
 
 class TboCudaGraphRunnerPlugin:
     def __init__(self):
-        pass  # TODO add logic here
+        self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32)
 
     def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
         if not global_server_args_dict["enable_two_batch_overlap"]:
@@ -124,7 +124,14 @@ class TboCudaGraphRunnerPlugin:
         # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
         assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"
 
-        TboForwardBatchPreparer.prepare(batch)
+        self._tbo_children_num_token_non_padded[...] = (
+            TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded(batch)
+        )
+
+        TboForwardBatchPreparer.prepare_raw(
+            batch,
+            tbo_children_num_token_non_padded=self._tbo_children_num_token_non_padded,
+        )
 
     def replay_prepare(
         self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
@@ -132,7 +139,20 @@ class TboCudaGraphRunnerPlugin:
         if not global_server_args_dict["enable_two_batch_overlap"]:
             return
 
-        pass  # TODO add logic here
+        tbo_split_seq_index, tbo_split_token_index = (
+            compute_split_indices_for_cuda_graph_replay(
+                forward_mode=forward_mode,
+                # TODO support bs!=num_tokens
+                cuda_graph_num_tokens=bs,
+            )
+        )
+
+        self._tbo_children_num_token_non_padded[...] = (
+            TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded_raw(
+                tbo_split_token_index=tbo_split_token_index,
+                num_token_non_padded=num_token_non_padded,
+            )
+        )
 
 
 class TboDPAttentionPreparer:
@@ -207,16 +227,23 @@ class TboDPAttentionPreparer:
 class TboForwardBatchPreparer:
     @classmethod
     def prepare(cls, batch: ForwardBatch):
-        from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
-
         if batch.tbo_split_seq_index is None:
             return
 
-        tbo_split_token_index = compute_split_token_index(
-            split_seq_index=batch.tbo_split_seq_index,
-            forward_mode=batch.forward_mode,
-            extend_seq_lens=batch.extend_seq_lens_cpu,
+        tbo_children_num_token_non_padded = (
+            cls.compute_tbo_children_num_token_non_padded(batch)
         )
+        cls.prepare_raw(
+            batch, tbo_children_num_token_non_padded=tbo_children_num_token_non_padded
+        )
+
+    @classmethod
+    def prepare_raw(
+        cls, batch: ForwardBatch, tbo_children_num_token_non_padded: torch.Tensor
+    ):
+        from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
+
+        tbo_split_token_index = cls._compute_split_token_index(batch)
 
         if _tbo_debug:
             logger.info(
@@ -229,6 +256,10 @@ class TboForwardBatchPreparer:
         assert isinstance(batch.attn_backend, TboAttnBackend)
         attn_backend_child_a, attn_backend_child_b = batch.attn_backend.children
 
+        [out_num_token_non_padded_a, out_num_token_non_padded_b] = (
+            tbo_children_num_token_non_padded
+        )
+
         child_a = cls.filter_batch(
             batch,
             start_token_index=0,
@@ -236,6 +267,7 @@ class TboForwardBatchPreparer:
             start_seq_index=0,
             end_seq_index=batch.tbo_split_seq_index,
             output_attn_backend=attn_backend_child_a,
+            out_num_token_non_padded=out_num_token_non_padded_a,
         )
         child_b = cls.filter_batch(
             batch,
@@ -244,6 +276,7 @@ class TboForwardBatchPreparer:
             start_seq_index=batch.tbo_split_seq_index,
             end_seq_index=batch.batch_size,
             output_attn_backend=attn_backend_child_b,
+            out_num_token_non_padded=out_num_token_non_padded_b,
         )
 
         assert batch.tbo_children is None
@@ -259,9 +292,8 @@ class TboForwardBatchPreparer:
         start_seq_index: int,
         end_seq_index: int,
         output_attn_backend: AttentionBackend,
+        out_num_token_non_padded: torch.Tensor,
     ):
-        from sglang.srt.managers.schedule_batch import global_server_args_dict
-
         num_tokens = batch.input_ids.shape[0]
         num_seqs = batch.batch_size
 
@@ -342,6 +374,7 @@ class TboForwardBatchPreparer:
                 ),
                 extend_num_tokens=extend_num_tokens,
                 attn_backend=output_attn_backend,
+                num_token_non_padded=out_num_token_non_padded,
                 tbo_split_seq_index=None,
                 tbo_parent_token_range=(start_token_index, end_token_index),
                 tbo_children=None,
@@ -357,7 +390,6 @@ class TboForwardBatchPreparer:
                 top_p_normalized_logprobs=False,
                 top_p=None,
                 mm_inputs=None,
-                num_token_non_padded=None,
             )
         )
 
@@ -372,6 +404,32 @@ class TboForwardBatchPreparer:
 
         return ForwardBatch(**output_dict)
 
+    @classmethod
+    def compute_tbo_children_num_token_non_padded(cls, batch: ForwardBatch):
+        return cls.compute_tbo_children_num_token_non_padded_raw(
+            tbo_split_token_index=cls._compute_split_token_index(batch),
+            num_token_non_padded=len(batch.input_ids),
+        )
+
+    @classmethod
+    def compute_tbo_children_num_token_non_padded_raw(
+        cls, tbo_split_token_index: int, num_token_non_padded: int
+    ):
+        # TODO we may make padding on both sub-batches to make it slightly more balanced
+        value_a = min(tbo_split_token_index, num_token_non_padded)
+        value_b = max(0, num_token_non_padded - tbo_split_token_index)
+        return torch.tensor([value_a, value_b], dtype=torch.int32).to(
+            device=global_server_args_dict["device"], non_blocking=True
+        )
+
+    @classmethod
+    def _compute_split_token_index(cls, batch: ForwardBatch):
+        return compute_split_token_index(
+            split_seq_index=batch.tbo_split_seq_index,
+            forward_mode=batch.forward_mode,
+            extend_seq_lens=batch.extend_seq_lens_cpu,
+        )
+
 
 def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
     if forward_mode.is_extend():