From fa6723f08ffe696b277c47e7bbf7a84dc4a8bd08 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Tue, 27 May 2025 12:22:59 -0700 Subject: [PATCH] Revert "fix communicator for non-dp lm head (#6662)" (#6677) --- python/sglang/srt/layers/communicator.py | 16 ++++------------ python/sglang/srt/models/qwen2_moe.py | 2 +- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index da30cbfd6..5dad481cd 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -50,18 +50,10 @@ class ScatterMode(Enum): FULL = auto() @staticmethod - def model_input_mode(): - """The scatter mode for model input data""" + def model_input_output(): + """The scatter mode for model forward pass input and output data""" return ScatterMode.TP_ATTN_FULL - @staticmethod - def model_output_mode(): - """The scatter mode for model output data""" - if global_server_args_dict["enable_dp_lm_head"]: - return ScatterMode.TP_ATTN_FULL - else: - return ScatterMode.FULL - @dataclass class _LayerModeComputationContext: @@ -103,7 +95,7 @@ class LayerScatterModes: @classmethod def _compute_layer_input_mode(cls, context: _LayerModeComputationContext): if context.layer_id == 0: - return ScatterMode.model_input_mode() + return ScatterMode.model_input_output() return cls._compute_layer_output_mode(context.previous_layer()) @classmethod @@ -134,7 +126,7 @@ class LayerScatterModes: def _compute_layer_output_mode(cls, context: _LayerModeComputationContext): mlp_mode = cls._compute_mlp_mode(context) if context.layer_id == context.num_layers - 1: - return ScatterMode.model_output_mode() + return ScatterMode.model_input_output() if mlp_mode == ScatterMode.SCATTERED: return ScatterMode.SCATTERED if mlp_mode == ScatterMode.FULL: diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 961c9198a..67e72d465 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -451,7 +451,7 @@ class Qwen2MoeModel(nn.Module): hidden_states, residual = model_forward_maybe_tbo( layers=self.layers, enable_tbo=True, - input_data_scatter_mode=ScatterMode.model_input_mode(), + input_data_scatter_mode=ScatterMode.model_input_output(), positions=positions, forward_batch=forward_batch, hidden_states=hidden_states,