Implement gather before attn (#6378)
This commit is contained in:
@@ -226,13 +226,13 @@ class LayerCommunicator:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CommunicateContext:
|
class CommunicateContext:
|
||||||
process_group_sizes: Dict["ScatterMode", int]
|
process_group_sizes: Dict[ScatterMode, int]
|
||||||
attn_tp_rank: int
|
attn_tp_rank: int
|
||||||
attn_tp_size: int
|
attn_tp_size: int
|
||||||
local_attn_dp_size: int
|
local_attn_dp_size: int
|
||||||
tp_size: int
|
tp_size: int
|
||||||
|
|
||||||
def is_same_group_size(self, a: "ScatterMode", b: "ScatterMode"):
|
def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
|
||||||
return self.process_group_sizes[a] == self.process_group_sizes[b]
|
return self.process_group_sizes[a] == self.process_group_sizes[b]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -244,6 +244,7 @@ class CommunicateContext:
|
|||||||
process_group_sizes = {
|
process_group_sizes = {
|
||||||
ScatterMode.SCATTERED: 1,
|
ScatterMode.SCATTERED: 1,
|
||||||
ScatterMode.TP_ATTN_FULL: attn_tp_size,
|
ScatterMode.TP_ATTN_FULL: attn_tp_size,
|
||||||
|
# TODO: support --moe-dense-tp-size > 1
|
||||||
ScatterMode.FULL: tp_size,
|
ScatterMode.FULL: tp_size,
|
||||||
}
|
}
|
||||||
return cls(
|
return cls(
|
||||||
@@ -323,11 +324,16 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
(hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
|
(hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
|
||||||
and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
|
and (
|
||||||
|
residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]
|
||||||
|
)
|
||||||
and (hidden_states_output_mode == ScatterMode.FULL)
|
and (hidden_states_output_mode == ScatterMode.FULL)
|
||||||
and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
|
and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
|
||||||
):
|
):
|
||||||
return CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states
|
return partial(
|
||||||
|
CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states_and_residual,
|
||||||
|
residual_input_mode=residual_input_mode,
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
(hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
|
(hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
|
||||||
@@ -360,13 +366,25 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
|||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _gather_hidden_states(
|
def _gather_hidden_states_and_residual(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
residual: torch.Tensor,
|
residual: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
layernorm: torch.nn.Module,
|
layernorm: torch.nn.Module,
|
||||||
context: CommunicateContext,
|
context: CommunicateContext,
|
||||||
|
*,
|
||||||
|
residual_input_mode,
|
||||||
):
|
):
|
||||||
|
if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
|
||||||
|
residual, local_residual = (
|
||||||
|
forward_batch.gathered_buffer[
|
||||||
|
: forward_batch.input_ids.shape[0]
|
||||||
|
].clone(),
|
||||||
|
residual,
|
||||||
|
)
|
||||||
|
attn_tp_all_gather(
|
||||||
|
list(residual.tensor_split(context.attn_tp_size)), local_residual
|
||||||
|
)
|
||||||
if context.local_attn_dp_size != 1:
|
if context.local_attn_dp_size != 1:
|
||||||
if context.attn_tp_rank == 0:
|
if context.attn_tp_rank == 0:
|
||||||
hidden_states += residual
|
hidden_states += residual
|
||||||
|
|||||||
Reference in New Issue
Block a user