From 02bf31ef29a31e6dc05cc16f2c8411da6fdf1b8c Mon Sep 17 00:00:00 2001
From: Atream <80757050+Atream@users.noreply.github.com>
Date: Sun, 22 Jun 2025 03:03:11 +0800
Subject: [PATCH] [fix] PD disaggregation when enable mtp and tp!=dp (#7420)

---
 python/sglang/srt/disaggregation/decode.py | 2 +-
 python/sglang/srt/layers/dp_attention.py   | 2 +-
 python/sglang/srt/models/deepseek_v2.py    | 3 ++-
 3 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py
index 47447ab73..f625fe171 100644
--- a/python/sglang/srt/disaggregation/decode.py
+++ b/python/sglang/srt/disaggregation/decode.py
@@ -766,7 +766,7 @@ class SchedulerDisaggregationDecodeMixin:
         if batch:
             result = self.run_batch(batch)
             if not delay_process:
-                self.prepare_mlp_sync_batch(batch, result)
+                self.process_batch_result(batch, result)
         return batch, result
 
     def get_next_disagg_decode_batch_to_run(
diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py
index 3b1a87db4..84857136a 100644
--- a/python/sglang/srt/layers/dp_attention.py
+++ b/python/sglang/srt/layers/dp_attention.py
@@ -310,4 +310,4 @@ def attn_tp_reduce_scatter(
 
 
 def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
-    return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
+    return get_attention_tp_group().all_gather(input_, output_tensor_list=output_list)
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 9886310e7..0993d9682 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1435,7 +1435,7 @@ class DeepseekV2DecoderLayer(nn.Module):
 
         self.layer_scatter_modes = LayerScatterModes.init_new(
             layer_id=layer_id,
-            num_layers=config.num_hidden_layers,
+            num_layers=1 if is_nextn else config.num_hidden_layers,
             is_layer_sparse=self.is_layer_sparse,
             is_previous_layer_sparse=is_previous_layer_sparse,
         )
@@ -1488,6 +1488,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
+
         hidden_states, residual = self.layer_communicator.prepare_attn(
             hidden_states, residual, forward_batch
         )