Implement gather before attn (#6378)
This commit is contained in:
@@ -226,13 +226,13 @@ class LayerCommunicator:
|
||||
|
||||
@dataclass
|
||||
class CommunicateContext:
|
||||
process_group_sizes: Dict["ScatterMode", int]
|
||||
process_group_sizes: Dict[ScatterMode, int]
|
||||
attn_tp_rank: int
|
||||
attn_tp_size: int
|
||||
local_attn_dp_size: int
|
||||
tp_size: int
|
||||
|
||||
def is_same_group_size(self, a: "ScatterMode", b: "ScatterMode"):
|
||||
def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
|
||||
return self.process_group_sizes[a] == self.process_group_sizes[b]
|
||||
|
||||
@classmethod
|
||||
@@ -244,6 +244,7 @@ class CommunicateContext:
|
||||
process_group_sizes = {
|
||||
ScatterMode.SCATTERED: 1,
|
||||
ScatterMode.TP_ATTN_FULL: attn_tp_size,
|
||||
# TODO: support --moe-dense-tp-size > 1
|
||||
ScatterMode.FULL: tp_size,
|
||||
}
|
||||
return cls(
|
||||
@@ -323,11 +324,16 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
||||
|
||||
if (
|
||||
(hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
|
||||
and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
|
||||
and (
|
||||
residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]
|
||||
)
|
||||
and (hidden_states_output_mode == ScatterMode.FULL)
|
||||
and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
|
||||
):
|
||||
return CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states
|
||||
return partial(
|
||||
CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states_and_residual,
|
||||
residual_input_mode=residual_input_mode,
|
||||
)
|
||||
|
||||
if (
|
||||
(hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
|
||||
@@ -360,13 +366,25 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
||||
return hidden_states, residual
|
||||
|
||||
@staticmethod
|
||||
def _gather_hidden_states(
|
||||
def _gather_hidden_states_and_residual(
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
layernorm: torch.nn.Module,
|
||||
context: CommunicateContext,
|
||||
*,
|
||||
residual_input_mode,
|
||||
):
|
||||
if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
|
||||
residual, local_residual = (
|
||||
forward_batch.gathered_buffer[
|
||||
: forward_batch.input_ids.shape[0]
|
||||
].clone(),
|
||||
residual,
|
||||
)
|
||||
attn_tp_all_gather(
|
||||
list(residual.tensor_split(context.attn_tp_size)), local_residual
|
||||
)
|
||||
if context.local_attn_dp_size != 1:
|
||||
if context.attn_tp_rank == 0:
|
||||
hidden_states += residual
|
||||
|
||||
Reference in New Issue
Block a user