diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 5dad481cd..da30cbfd6 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -50,10 +50,18 @@ class ScatterMode(Enum): FULL = auto() @staticmethod - def model_input_output(): - """The scatter mode for model forward pass input and output data""" + def model_input_mode(): + """The scatter mode for model input data""" return ScatterMode.TP_ATTN_FULL + @staticmethod + def model_output_mode(): + """The scatter mode for model output data""" + if global_server_args_dict["enable_dp_lm_head"]: + return ScatterMode.TP_ATTN_FULL + else: + return ScatterMode.FULL + @dataclass class _LayerModeComputationContext: @@ -95,7 +103,7 @@ class LayerScatterModes: @classmethod def _compute_layer_input_mode(cls, context: _LayerModeComputationContext): if context.layer_id == 0: - return ScatterMode.model_input_output() + return ScatterMode.model_input_mode() return cls._compute_layer_output_mode(context.previous_layer()) @classmethod @@ -126,7 +134,7 @@ class LayerScatterModes: def _compute_layer_output_mode(cls, context: _LayerModeComputationContext): mlp_mode = cls._compute_mlp_mode(context) if context.layer_id == context.num_layers - 1: - return ScatterMode.model_input_output() + return ScatterMode.model_output_mode() if mlp_mode == ScatterMode.SCATTERED: return ScatterMode.SCATTERED if mlp_mode == ScatterMode.FULL: diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 67e72d465..961c9198a 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -451,7 +451,7 @@ class Qwen2MoeModel(nn.Module): hidden_states, residual = model_forward_maybe_tbo( layers=self.layers, enable_tbo=True, - input_data_scatter_mode=ScatterMode.model_input_output(), + input_data_scatter_mode=ScatterMode.model_input_mode(), positions=positions, forward_batch=forward_batch, hidden_states=hidden_states,