Revert "fix communicator for non-dp lm head (#6662)" (#6677)

This commit is contained in:
Yineng Zhang
2025-05-27 12:22:59 -07:00
committed by GitHub
parent 673ff668f7
commit fa6723f08f
2 changed files with 5 additions and 13 deletions

View File

@@ -50,18 +50,10 @@ class ScatterMode(Enum):
     FULL = auto()

     @staticmethod
-    def model_input_mode():
-        """The scatter mode for model input data"""
+    def model_input_output():
+        """The scatter mode for model forward pass input and output data"""
         return ScatterMode.TP_ATTN_FULL

-    @staticmethod
-    def model_output_mode():
-        """The scatter mode for model output data"""
-        if global_server_args_dict["enable_dp_lm_head"]:
-            return ScatterMode.TP_ATTN_FULL
-        else:
-            return ScatterMode.FULL

 @dataclass
 class _LayerModeComputationContext:
@@ -103,7 +95,7 @@ class LayerScatterModes:
     @classmethod
     def _compute_layer_input_mode(cls, context: _LayerModeComputationContext):
         if context.layer_id == 0:
-            return ScatterMode.model_input_mode()
+            return ScatterMode.model_input_output()
         return cls._compute_layer_output_mode(context.previous_layer())

     @classmethod
@@ -134,7 +126,7 @@ class LayerScatterModes:
     def _compute_layer_output_mode(cls, context: _LayerModeComputationContext):
         mlp_mode = cls._compute_mlp_mode(context)
         if context.layer_id == context.num_layers - 1:
-            return ScatterMode.model_output_mode()
+            return ScatterMode.model_input_output()
         if mlp_mode == ScatterMode.SCATTERED:
             return ScatterMode.SCATTERED
         if mlp_mode == ScatterMode.FULL:

View File

@@ -451,7 +451,7 @@ class Qwen2MoeModel(nn.Module):
         hidden_states, residual = model_forward_maybe_tbo(
             layers=self.layers,
             enable_tbo=True,
-            input_data_scatter_mode=ScatterMode.model_input_mode(),
+            input_data_scatter_mode=ScatterMode.model_input_output(),
             positions=positions,
             forward_batch=forward_batch,
             hidden_states=hidden_states,