@@ -50,18 +50,10 @@ class ScatterMode(Enum):
|
||||
FULL = auto()
|
||||
|
||||
@staticmethod
|
||||
def model_input_mode():
|
||||
"""The scatter mode for model input data"""
|
||||
def model_input_output():
|
||||
"""The scatter mode for model forward pass input and output data"""
|
||||
return ScatterMode.TP_ATTN_FULL
|
||||
|
||||
@staticmethod
|
||||
def model_output_mode():
|
||||
"""The scatter mode for model output data"""
|
||||
if global_server_args_dict["enable_dp_lm_head"]:
|
||||
return ScatterMode.TP_ATTN_FULL
|
||||
else:
|
||||
return ScatterMode.FULL
|
||||
|
||||
|
||||
@dataclass
|
||||
class _LayerModeComputationContext:
|
||||
@@ -103,7 +95,7 @@ class LayerScatterModes:
|
||||
@classmethod
|
||||
def _compute_layer_input_mode(cls, context: _LayerModeComputationContext):
|
||||
if context.layer_id == 0:
|
||||
return ScatterMode.model_input_mode()
|
||||
return ScatterMode.model_input_output()
|
||||
return cls._compute_layer_output_mode(context.previous_layer())
|
||||
|
||||
@classmethod
|
||||
@@ -134,7 +126,7 @@ class LayerScatterModes:
|
||||
def _compute_layer_output_mode(cls, context: _LayerModeComputationContext):
|
||||
mlp_mode = cls._compute_mlp_mode(context)
|
||||
if context.layer_id == context.num_layers - 1:
|
||||
return ScatterMode.model_output_mode()
|
||||
return ScatterMode.model_input_output()
|
||||
if mlp_mode == ScatterMode.SCATTERED:
|
||||
return ScatterMode.SCATTERED
|
||||
if mlp_mode == ScatterMode.FULL:
|
||||
|
||||
@@ -451,7 +451,7 @@ class Qwen2MoeModel(nn.Module):
|
||||
hidden_states, residual = model_forward_maybe_tbo(
|
||||
layers=self.layers,
|
||||
enable_tbo=True,
|
||||
input_data_scatter_mode=ScatterMode.model_input_mode(),
|
||||
input_data_scatter_mode=ScatterMode.model_input_output(),
|
||||
positions=positions,
|
||||
forward_batch=forward_batch,
|
||||
hidden_states=hidden_states,
|
||||
|
||||
Reference in New Issue
Block a user