@@ -50,18 +50,10 @@ class ScatterMode(Enum):
|
||||
FULL = auto()
|
||||
|
||||
@staticmethod
|
||||
def model_input_mode():
|
||||
"""The scatter mode for model input data"""
|
||||
def model_input_output():
|
||||
"""The scatter mode for model forward pass input and output data"""
|
||||
return ScatterMode.TP_ATTN_FULL
|
||||
|
||||
@staticmethod
|
||||
def model_output_mode():
|
||||
"""The scatter mode for model output data"""
|
||||
if global_server_args_dict["enable_dp_lm_head"]:
|
||||
return ScatterMode.TP_ATTN_FULL
|
||||
else:
|
||||
return ScatterMode.FULL
|
||||
|
||||
|
||||
@dataclass
|
||||
class _LayerModeComputationContext:
|
||||
@@ -103,7 +95,7 @@ class LayerScatterModes:
|
||||
@classmethod
|
||||
def _compute_layer_input_mode(cls, context: _LayerModeComputationContext):
|
||||
if context.layer_id == 0:
|
||||
return ScatterMode.model_input_mode()
|
||||
return ScatterMode.model_input_output()
|
||||
return cls._compute_layer_output_mode(context.previous_layer())
|
||||
|
||||
@classmethod
|
||||
@@ -134,7 +126,7 @@ class LayerScatterModes:
|
||||
def _compute_layer_output_mode(cls, context: _LayerModeComputationContext):
|
||||
mlp_mode = cls._compute_mlp_mode(context)
|
||||
if context.layer_id == context.num_layers - 1:
|
||||
return ScatterMode.model_output_mode()
|
||||
return ScatterMode.model_input_output()
|
||||
if mlp_mode == ScatterMode.SCATTERED:
|
||||
return ScatterMode.SCATTERED
|
||||
if mlp_mode == ScatterMode.FULL:
|
||||
|
||||
Reference in New Issue
Block a user