fix communicator for non-dp lm head (#6662)
This commit is contained in:
@@ -50,10 +50,18 @@ class ScatterMode(Enum):
|
|||||||
FULL = auto()
|
FULL = auto()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def model_input_output():
|
def model_input_mode():
|
||||||
"""The scatter mode for model forward pass input and output data"""
|
"""The scatter mode for model input data"""
|
||||||
return ScatterMode.TP_ATTN_FULL
|
return ScatterMode.TP_ATTN_FULL
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def model_output_mode():
|
||||||
|
"""The scatter mode for model output data"""
|
||||||
|
if global_server_args_dict["enable_dp_lm_head"]:
|
||||||
|
return ScatterMode.TP_ATTN_FULL
|
||||||
|
else:
|
||||||
|
return ScatterMode.FULL
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _LayerModeComputationContext:
|
class _LayerModeComputationContext:
|
||||||
@@ -95,7 +103,7 @@ class LayerScatterModes:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def _compute_layer_input_mode(cls, context: _LayerModeComputationContext):
|
def _compute_layer_input_mode(cls, context: _LayerModeComputationContext):
|
||||||
if context.layer_id == 0:
|
if context.layer_id == 0:
|
||||||
return ScatterMode.model_input_output()
|
return ScatterMode.model_input_mode()
|
||||||
return cls._compute_layer_output_mode(context.previous_layer())
|
return cls._compute_layer_output_mode(context.previous_layer())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -126,7 +134,7 @@ class LayerScatterModes:
|
|||||||
def _compute_layer_output_mode(cls, context: _LayerModeComputationContext):
|
def _compute_layer_output_mode(cls, context: _LayerModeComputationContext):
|
||||||
mlp_mode = cls._compute_mlp_mode(context)
|
mlp_mode = cls._compute_mlp_mode(context)
|
||||||
if context.layer_id == context.num_layers - 1:
|
if context.layer_id == context.num_layers - 1:
|
||||||
return ScatterMode.model_input_output()
|
return ScatterMode.model_output_mode()
|
||||||
if mlp_mode == ScatterMode.SCATTERED:
|
if mlp_mode == ScatterMode.SCATTERED:
|
||||||
return ScatterMode.SCATTERED
|
return ScatterMode.SCATTERED
|
||||||
if mlp_mode == ScatterMode.FULL:
|
if mlp_mode == ScatterMode.FULL:
|
||||||
|
|||||||
@@ -451,7 +451,7 @@ class Qwen2MoeModel(nn.Module):
|
|||||||
hidden_states, residual = model_forward_maybe_tbo(
|
hidden_states, residual = model_forward_maybe_tbo(
|
||||||
layers=self.layers,
|
layers=self.layers,
|
||||||
enable_tbo=True,
|
enable_tbo=True,
|
||||||
input_data_scatter_mode=ScatterMode.model_input_output(),
|
input_data_scatter_mode=ScatterMode.model_input_mode(),
|
||||||
positions=positions,
|
positions=positions,
|
||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
|
|||||||
Reference in New Issue
Block a user