Utilize static dispatching for communicator (#6577)

This commit is contained in:
fzyzcjy
2025-05-25 08:34:35 +08:00
committed by GitHub
parent b2388433be
commit f456037396

View File

@@ -14,7 +14,8 @@
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from typing import Dict, Optional, Tuple from functools import partial
from typing import Dict, Optional
import torch.distributed import torch.distributed
@@ -145,6 +146,36 @@ class LayerCommunicator:
ScatterMode.FULL: self.tp_size, ScatterMode.FULL: self.tp_size,
} }
self._context = _Context(
process_group_sizes=self.process_group_sizes,
attn_tp_rank=self.attn_tp_rank,
attn_tp_size=self.attn_tp_size,
local_attn_dp_size=self.local_attn_dp_size,
tp_size=self.tp_size,
)
self._communicate_simple_fn = _CommunicateSimpleFn.get_fn(
input_mode=self.layer_scatter_modes.layer_input_mode,
output_mode=self.layer_scatter_modes.attn_mode,
context=self._context,
)
self._communicate_with_all_reduce_and_layer_norm_fn = (
_CommunicateWithAllReduceAndLayerNormFn.get_fn(
hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
residual_input_mode=self.layer_scatter_modes.layer_input_mode,
hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,
residual_output_mode=self.layer_scatter_modes.middle_residual_mode,
context=self._context,
)
)
self._communicate_summable_tensor_pair_fn = (
_CommunicateSummableTensorPairFn.get_fn(
hidden_states_input_mode=self.layer_scatter_modes.mlp_mode,
residual_input_mode=self.layer_scatter_modes.middle_residual_mode,
output_mode=self.layer_scatter_modes.layer_output_mode,
context=self._context,
)
)
def prepare_attn( def prepare_attn(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -160,12 +191,10 @@ class LayerCommunicator:
else: else:
hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = _communicate_simple( hidden_states = self._communicate_simple_fn(
hidden_states=hidden_states, hidden_states=hidden_states,
forward_batch=forward_batch, forward_batch=forward_batch,
input_mode=self.layer_scatter_modes.layer_input_mode, context=self._context,
output_mode=self.layer_scatter_modes.attn_mode,
context=self._compute_context(forward_batch),
) )
return hidden_states, residual return hidden_states, residual
@@ -176,16 +205,12 @@ class LayerCommunicator:
residual: torch.Tensor, residual: torch.Tensor,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
): ):
return _communicate_with_all_reduce_and_layer_norm( return self._communicate_with_all_reduce_and_layer_norm_fn(
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual, residual=residual,
forward_batch=forward_batch, forward_batch=forward_batch,
hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
residual_input_mode=self.layer_scatter_modes.layer_input_mode,
hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,
residual_output_mode=self.layer_scatter_modes.middle_residual_mode,
layernorm=self.post_attention_layernorm, layernorm=self.post_attention_layernorm,
context=self._compute_context(forward_batch), context=self._context,
) )
def postprocess_layer( def postprocess_layer(
@@ -194,58 +219,16 @@ class LayerCommunicator:
residual: torch.Tensor, residual: torch.Tensor,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
): ):
return _communicate_summable_tensor_pair( return self._communicate_summable_tensor_pair_fn(
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual, residual=residual,
forward_batch=forward_batch, forward_batch=forward_batch,
hidden_states_input_mode=self.layer_scatter_modes.mlp_mode, context=self._context,
residual_input_mode=self.layer_scatter_modes.middle_residual_mode,
output_mode=self.layer_scatter_modes.layer_output_mode,
context=self._compute_context(forward_batch),
) )
def _compute_context(self, forward_batch: ForwardBatch):
return _Context(
num_tokens_of_mode=_compute_num_tokens_of_mode(
forward_batch,
attn_tp_rank=self.attn_tp_rank,
attn_tp_size=self.attn_tp_size,
),
process_group_sizes=self.process_group_sizes,
attn_tp_rank=self.attn_tp_rank,
attn_tp_size=self.attn_tp_size,
local_attn_dp_size=self.local_attn_dp_size,
tp_size=self.tp_size,
)
def _compute_num_tokens_of_mode(
forward_batch: ForwardBatch, attn_tp_rank: int, attn_tp_size: int
):
tp_attn_full_num_tokens = forward_batch.input_ids.shape[0]
return {
ScatterMode.SCATTERED: _torch_tensor_split_len(
tp_attn_full_num_tokens, attn_tp_size, attn_tp_rank
),
ScatterMode.TP_ATTN_FULL: tp_attn_full_num_tokens,
ScatterMode.FULL: (
forward_batch.gathered_buffer.shape[0]
if global_server_args_dict["enable_dp_attention"]
else forward_batch.input_ids.shape[0]
),
}
def _torch_tensor_split_len(tensor_len: int, n: int, output_index: int):
if output_index < int(tensor_len % n):
return int(tensor_len / n) + 1
else:
return int(tensor_len / n)
@dataclass @dataclass
class _Context: class _Context:
num_tokens_of_mode: Dict["ScatterMode", int]
process_group_sizes: Dict["ScatterMode", int] process_group_sizes: Dict["ScatterMode", int]
attn_tp_rank: int attn_tp_rank: int
attn_tp_size: int attn_tp_size: int
@@ -255,41 +238,38 @@ class _Context:
def is_same_group_size(self, a: "ScatterMode", b: "ScatterMode"): def is_same_group_size(self, a: "ScatterMode", b: "ScatterMode"):
return self.process_group_sizes[a] == self.process_group_sizes[b] return self.process_group_sizes[a] == self.process_group_sizes[b]
def check_shape(self, x: torch.Tensor, mode: ScatterMode):
if x is None:
return
actual_num_tokens = x.shape[0] class _CommunicateSimpleFn:
expect_num_tokens = self.num_tokens_of_mode[mode] @staticmethod
assert ( def get_fn(
actual_num_tokens == expect_num_tokens
), f"{actual_num_tokens=} {expect_num_tokens=} {mode=} {x.shape=} {self.num_tokens_of_mode=} {self.process_group_sizes=}"
return x
def check_shapes(
self, xs: Tuple[torch.Tensor, ...], modes: Tuple[ScatterMode, ...]
) -> Tuple[torch.Tensor, ...]:
return tuple(
[self.check_shape(x, mode) for x, mode in zip(xs, modes, strict=True)]
)
def _communicate_simple(
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
input_mode: ScatterMode, input_mode: ScatterMode,
output_mode: ScatterMode, output_mode: ScatterMode,
context: _Context, context: _Context,
) -> torch.Tensor: ):
def _inner():
nonlocal hidden_states
if context.is_same_group_size(input_mode, output_mode): if context.is_same_group_size(input_mode, output_mode):
return hidden_states return _CommunicateSimpleFn._trivial
if (input_mode == ScatterMode.SCATTERED) and ( if (input_mode == ScatterMode.SCATTERED) and (
output_mode == ScatterMode.TP_ATTN_FULL output_mode == ScatterMode.TP_ATTN_FULL
): ):
return _CommunicateSimpleFn._scattered_to_tp_attn_full
raise NotImplementedError(f"{input_mode=} {output_mode=}")
@staticmethod
def _trivial(
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
context: _Context,
) -> torch.Tensor:
return hidden_states
@staticmethod
def _scattered_to_tp_attn_full(
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
context: _Context,
) -> torch.Tensor:
hidden_states, local_hidden_states = ( hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states, hidden_states,
@@ -300,30 +280,21 @@ def _communicate_simple(
) )
return hidden_states return hidden_states
raise NotImplementedError(f"{input_mode=} {output_mode=}")
context.check_shape(hidden_states, input_mode) class _CommunicateWithAllReduceAndLayerNormFn:
return context.check_shape(_inner(), output_mode)
def _communicate_with_all_reduce_and_layer_norm(
hidden_states: torch.Tensor,
residual: torch.Tensor,
hidden_states_input_mode: ScatterMode,
residual_input_mode: ScatterMode,
hidden_states_output_mode: ScatterMode,
residual_output_mode: ScatterMode,
forward_batch: ForwardBatch,
layernorm: torch.nn.Module,
context: _Context,
):
"""Besides communication, needs to """Besides communication, needs to
1. All reduce in tp_attn_group on hidden_states 1. All reduce in tp_attn_group on hidden_states
2. Apply layer norm 2. Apply layer norm
""" """
def _inner(): @staticmethod
nonlocal hidden_states, residual def get_fn(
hidden_states_input_mode: ScatterMode,
residual_input_mode: ScatterMode,
hidden_states_output_mode: ScatterMode,
residual_output_mode: ScatterMode,
context: _Context,
):
if ( if (
context.is_same_group_size( context.is_same_group_size(
@@ -332,16 +303,53 @@ def _communicate_with_all_reduce_and_layer_norm(
and context.is_same_group_size(residual_input_mode, residual_output_mode) and context.is_same_group_size(residual_input_mode, residual_output_mode)
and context.attn_tp_size == 1 and context.attn_tp_size == 1
): ):
# TODO move these `if shape != 0` into LayerNorm itself return _CommunicateWithAllReduceAndLayerNormFn._simple
if hidden_states.shape[0] != 0:
hidden_states, residual = layernorm(hidden_states, residual)
return hidden_states, residual
if ( if (
(hidden_states_input_mode == ScatterMode.TP_ATTN_FULL) (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
and (residual_input_mode == ScatterMode.TP_ATTN_FULL) and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
and (hidden_states_output_mode == ScatterMode.FULL) and (hidden_states_output_mode == ScatterMode.FULL)
and (residual_output_mode == ScatterMode.TP_ATTN_FULL) and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
):
return _CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states
if (
(hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
and (
residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]
)
and (hidden_states_output_mode == ScatterMode.SCATTERED)
and (residual_output_mode == ScatterMode.SCATTERED)
):
return partial(
_CommunicateWithAllReduceAndLayerNormFn._scatter_hidden_states_and_residual,
residual_input_mode=residual_input_mode,
)
raise NotImplementedError(
f"{hidden_states_input_mode=} {residual_input_mode=} {residual_output_mode=} {residual_output_mode=}"
)
@staticmethod
def _simple(
hidden_states: torch.Tensor,
residual: torch.Tensor,
forward_batch: ForwardBatch,
layernorm: torch.nn.Module,
context: _Context,
):
# TODO move these `if shape != 0` into LayerNorm itself
if hidden_states.shape[0] != 0:
hidden_states, residual = layernorm(hidden_states, residual)
return hidden_states, residual
@staticmethod
def _gather_hidden_states(
hidden_states: torch.Tensor,
residual: torch.Tensor,
forward_batch: ForwardBatch,
layernorm: torch.nn.Module,
context: _Context,
): ):
if context.local_attn_dp_size != 1: if context.local_attn_dp_size != 1:
if context.attn_tp_rank == 0: if context.attn_tp_rank == 0:
@@ -359,41 +367,30 @@ def _communicate_with_all_reduce_and_layer_norm(
hidden_states, residual = layernorm(hidden_states, residual) hidden_states, residual = layernorm(hidden_states, residual)
return hidden_states, residual return hidden_states, residual
if ( @staticmethod
(hidden_states_input_mode == ScatterMode.TP_ATTN_FULL) def _scatter_hidden_states_and_residual(
and ( hidden_states: torch.Tensor,
residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL] residual: torch.Tensor,
) forward_batch: ForwardBatch,
and (hidden_states_output_mode == ScatterMode.SCATTERED) layernorm: torch.nn.Module,
and (residual_output_mode == ScatterMode.SCATTERED) context: _Context,
*,
residual_input_mode,
): ):
tensor_list = list(hidden_states.tensor_split(context.attn_tp_size)) tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
hidden_states = tensor_list[context.attn_tp_rank] hidden_states = tensor_list[context.attn_tp_rank]
attn_tp_reduce_scatter(hidden_states, tensor_list) attn_tp_reduce_scatter(hidden_states, tensor_list)
if residual_input_mode == ScatterMode.TP_ATTN_FULL: if residual_input_mode == ScatterMode.TP_ATTN_FULL:
residual = residual.tensor_split(context.attn_tp_size)[ residual = residual.tensor_split(context.attn_tp_size)[context.attn_tp_rank]
context.attn_tp_rank
]
if hidden_states.shape[0] != 0: if hidden_states.shape[0] != 0:
hidden_states, residual = layernorm(hidden_states, residual) hidden_states, residual = layernorm(hidden_states, residual)
return hidden_states, residual return hidden_states, residual
raise NotImplementedError(
f"{hidden_states_input_mode=} {residual_input_mode=} {residual_output_mode=} {residual_output_mode=}"
)
context.check_shapes( class _CommunicateSummableTensorPairFn:
(hidden_states, residual), (hidden_states_input_mode, residual_input_mode)
)
return context.check_shapes(
_inner(), (hidden_states_output_mode, residual_output_mode)
)
@staticmethod
def _communicate_summable_tensor_pair( def get_fn(
hidden_states: torch.Tensor,
residual: torch.Tensor,
forward_batch: ForwardBatch,
hidden_states_input_mode: ScatterMode, hidden_states_input_mode: ScatterMode,
residual_input_mode: ScatterMode, residual_input_mode: ScatterMode,
output_mode: ScatterMode, output_mode: ScatterMode,
@@ -401,18 +398,44 @@ def _communicate_summable_tensor_pair(
): ):
"""It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed.""" """It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed."""
def _inner():
nonlocal hidden_states, residual
if context.is_same_group_size( if context.is_same_group_size(
hidden_states_input_mode, output_mode hidden_states_input_mode, output_mode
) and context.is_same_group_size(residual_input_mode, output_mode): ) and context.is_same_group_size(residual_input_mode, output_mode):
return hidden_states, residual return _CommunicateSummableTensorPairFn._trivial
if ( if (
(hidden_states_input_mode == ScatterMode.FULL) (hidden_states_input_mode == ScatterMode.FULL)
and (residual_input_mode == ScatterMode.TP_ATTN_FULL) and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
and (output_mode == ScatterMode.TP_ATTN_FULL) and (output_mode == ScatterMode.TP_ATTN_FULL)
):
return _CommunicateSummableTensorPairFn._scatter_hidden_states
if (
(hidden_states_input_mode == ScatterMode.SCATTERED)
and (residual_input_mode == ScatterMode.SCATTERED)
and (output_mode == ScatterMode.TP_ATTN_FULL)
):
return _CommunicateSummableTensorPairFn._gather
raise NotImplementedError(
f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
)
@staticmethod
def _trivial(
hidden_states: torch.Tensor,
residual: torch.Tensor,
forward_batch: ForwardBatch,
context: _Context,
):
return hidden_states, residual
@staticmethod
def _scatter_hidden_states(
hidden_states: torch.Tensor,
residual: torch.Tensor,
forward_batch: ForwardBatch,
context: _Context,
): ):
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
# important: forward batch.gathered_buffer is used both after scatter and after gather. # important: forward batch.gathered_buffer is used both after scatter and after gather.
@@ -424,10 +447,12 @@ def _communicate_summable_tensor_pair(
dp_scatter(hidden_states, global_hidden_states, forward_batch) dp_scatter(hidden_states, global_hidden_states, forward_batch)
return hidden_states, residual return hidden_states, residual
if ( @staticmethod
(hidden_states_input_mode == ScatterMode.SCATTERED) def _gather(
and (residual_input_mode == ScatterMode.SCATTERED) hidden_states: torch.Tensor,
and (output_mode == ScatterMode.TP_ATTN_FULL) residual: torch.Tensor,
forward_batch: ForwardBatch,
context: _Context,
): ):
hidden_states += residual hidden_states += residual
residual = None residual = None
@@ -440,12 +465,3 @@ def _communicate_summable_tensor_pair(
local_hidden_states, local_hidden_states,
) )
return hidden_states, residual return hidden_states, residual
raise NotImplementedError(
f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
)
context.check_shapes(
(hidden_states, residual), (hidden_states_input_mode, residual_input_mode)
)
return context.check_shapes(_inner(), (output_mode, output_mode))