Utilize static dispatching for communicator (#6577)
This commit is contained in:
@@ -14,7 +14,8 @@
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Dict, Optional, Tuple
|
||||
from functools import partial
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch.distributed
|
||||
|
||||
@@ -145,6 +146,36 @@ class LayerCommunicator:
|
||||
ScatterMode.FULL: self.tp_size,
|
||||
}
|
||||
|
||||
self._context = _Context(
|
||||
process_group_sizes=self.process_group_sizes,
|
||||
attn_tp_rank=self.attn_tp_rank,
|
||||
attn_tp_size=self.attn_tp_size,
|
||||
local_attn_dp_size=self.local_attn_dp_size,
|
||||
tp_size=self.tp_size,
|
||||
)
|
||||
self._communicate_simple_fn = _CommunicateSimpleFn.get_fn(
|
||||
input_mode=self.layer_scatter_modes.layer_input_mode,
|
||||
output_mode=self.layer_scatter_modes.attn_mode,
|
||||
context=self._context,
|
||||
)
|
||||
self._communicate_with_all_reduce_and_layer_norm_fn = (
|
||||
_CommunicateWithAllReduceAndLayerNormFn.get_fn(
|
||||
hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
|
||||
residual_input_mode=self.layer_scatter_modes.layer_input_mode,
|
||||
hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,
|
||||
residual_output_mode=self.layer_scatter_modes.middle_residual_mode,
|
||||
context=self._context,
|
||||
)
|
||||
)
|
||||
self._communicate_summable_tensor_pair_fn = (
|
||||
_CommunicateSummableTensorPairFn.get_fn(
|
||||
hidden_states_input_mode=self.layer_scatter_modes.mlp_mode,
|
||||
residual_input_mode=self.layer_scatter_modes.middle_residual_mode,
|
||||
output_mode=self.layer_scatter_modes.layer_output_mode,
|
||||
context=self._context,
|
||||
)
|
||||
)
|
||||
|
||||
def prepare_attn(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -160,12 +191,10 @@ class LayerCommunicator:
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
hidden_states = _communicate_simple(
|
||||
hidden_states = self._communicate_simple_fn(
|
||||
hidden_states=hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
input_mode=self.layer_scatter_modes.layer_input_mode,
|
||||
output_mode=self.layer_scatter_modes.attn_mode,
|
||||
context=self._compute_context(forward_batch),
|
||||
context=self._context,
|
||||
)
|
||||
|
||||
return hidden_states, residual
|
||||
@@ -176,16 +205,12 @@ class LayerCommunicator:
|
||||
residual: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
):
|
||||
return _communicate_with_all_reduce_and_layer_norm(
|
||||
return self._communicate_with_all_reduce_and_layer_norm_fn(
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
forward_batch=forward_batch,
|
||||
hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
|
||||
residual_input_mode=self.layer_scatter_modes.layer_input_mode,
|
||||
hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,
|
||||
residual_output_mode=self.layer_scatter_modes.middle_residual_mode,
|
||||
layernorm=self.post_attention_layernorm,
|
||||
context=self._compute_context(forward_batch),
|
||||
context=self._context,
|
||||
)
|
||||
|
||||
def postprocess_layer(
|
||||
@@ -194,58 +219,16 @@ class LayerCommunicator:
|
||||
residual: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
):
|
||||
return _communicate_summable_tensor_pair(
|
||||
return self._communicate_summable_tensor_pair_fn(
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
forward_batch=forward_batch,
|
||||
hidden_states_input_mode=self.layer_scatter_modes.mlp_mode,
|
||||
residual_input_mode=self.layer_scatter_modes.middle_residual_mode,
|
||||
output_mode=self.layer_scatter_modes.layer_output_mode,
|
||||
context=self._compute_context(forward_batch),
|
||||
context=self._context,
|
||||
)
|
||||
|
||||
def _compute_context(self, forward_batch: ForwardBatch):
|
||||
return _Context(
|
||||
num_tokens_of_mode=_compute_num_tokens_of_mode(
|
||||
forward_batch,
|
||||
attn_tp_rank=self.attn_tp_rank,
|
||||
attn_tp_size=self.attn_tp_size,
|
||||
),
|
||||
process_group_sizes=self.process_group_sizes,
|
||||
attn_tp_rank=self.attn_tp_rank,
|
||||
attn_tp_size=self.attn_tp_size,
|
||||
local_attn_dp_size=self.local_attn_dp_size,
|
||||
tp_size=self.tp_size,
|
||||
)
|
||||
|
||||
|
||||
def _compute_num_tokens_of_mode(
|
||||
forward_batch: ForwardBatch, attn_tp_rank: int, attn_tp_size: int
|
||||
):
|
||||
tp_attn_full_num_tokens = forward_batch.input_ids.shape[0]
|
||||
return {
|
||||
ScatterMode.SCATTERED: _torch_tensor_split_len(
|
||||
tp_attn_full_num_tokens, attn_tp_size, attn_tp_rank
|
||||
),
|
||||
ScatterMode.TP_ATTN_FULL: tp_attn_full_num_tokens,
|
||||
ScatterMode.FULL: (
|
||||
forward_batch.gathered_buffer.shape[0]
|
||||
if global_server_args_dict["enable_dp_attention"]
|
||||
else forward_batch.input_ids.shape[0]
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _torch_tensor_split_len(tensor_len: int, n: int, output_index: int):
|
||||
if output_index < int(tensor_len % n):
|
||||
return int(tensor_len / n) + 1
|
||||
else:
|
||||
return int(tensor_len / n)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Context:
|
||||
num_tokens_of_mode: Dict["ScatterMode", int]
|
||||
process_group_sizes: Dict["ScatterMode", int]
|
||||
attn_tp_rank: int
|
||||
attn_tp_size: int
|
||||
@@ -255,75 +238,63 @@ class _Context:
|
||||
def is_same_group_size(self, a: "ScatterMode", b: "ScatterMode"):
|
||||
return self.process_group_sizes[a] == self.process_group_sizes[b]
|
||||
|
||||
def check_shape(self, x: torch.Tensor, mode: ScatterMode):
|
||||
if x is None:
|
||||
return
|
||||
|
||||
actual_num_tokens = x.shape[0]
|
||||
expect_num_tokens = self.num_tokens_of_mode[mode]
|
||||
assert (
|
||||
actual_num_tokens == expect_num_tokens
|
||||
), f"{actual_num_tokens=} {expect_num_tokens=} {mode=} {x.shape=} {self.num_tokens_of_mode=} {self.process_group_sizes=}"
|
||||
return x
|
||||
|
||||
def check_shapes(
|
||||
self, xs: Tuple[torch.Tensor, ...], modes: Tuple[ScatterMode, ...]
|
||||
) -> Tuple[torch.Tensor, ...]:
|
||||
return tuple(
|
||||
[self.check_shape(x, mode) for x, mode in zip(xs, modes, strict=True)]
|
||||
)
|
||||
|
||||
|
||||
def _communicate_simple(
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
input_mode: ScatterMode,
|
||||
output_mode: ScatterMode,
|
||||
context: _Context,
|
||||
) -> torch.Tensor:
|
||||
def _inner():
|
||||
nonlocal hidden_states
|
||||
|
||||
class _CommunicateSimpleFn:
|
||||
@staticmethod
|
||||
def get_fn(
|
||||
input_mode: ScatterMode,
|
||||
output_mode: ScatterMode,
|
||||
context: _Context,
|
||||
):
|
||||
if context.is_same_group_size(input_mode, output_mode):
|
||||
return hidden_states
|
||||
return _CommunicateSimpleFn._trivial
|
||||
|
||||
if (input_mode == ScatterMode.SCATTERED) and (
|
||||
output_mode == ScatterMode.TP_ATTN_FULL
|
||||
):
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
attn_tp_all_gather(
|
||||
list(hidden_states.tensor_split(context.attn_tp_size)),
|
||||
local_hidden_states,
|
||||
)
|
||||
return hidden_states
|
||||
return _CommunicateSimpleFn._scattered_to_tp_attn_full
|
||||
|
||||
raise NotImplementedError(f"{input_mode=} {output_mode=}")
|
||||
|
||||
context.check_shape(hidden_states, input_mode)
|
||||
return context.check_shape(_inner(), output_mode)
|
||||
@staticmethod
|
||||
def _trivial(
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
context: _Context,
|
||||
) -> torch.Tensor:
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def _scattered_to_tp_attn_full(
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
context: _Context,
|
||||
) -> torch.Tensor:
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
attn_tp_all_gather(
|
||||
list(hidden_states.tensor_split(context.attn_tp_size)),
|
||||
local_hidden_states,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def _communicate_with_all_reduce_and_layer_norm(
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
hidden_states_input_mode: ScatterMode,
|
||||
residual_input_mode: ScatterMode,
|
||||
hidden_states_output_mode: ScatterMode,
|
||||
residual_output_mode: ScatterMode,
|
||||
forward_batch: ForwardBatch,
|
||||
layernorm: torch.nn.Module,
|
||||
context: _Context,
|
||||
):
|
||||
class _CommunicateWithAllReduceAndLayerNormFn:
|
||||
"""Besides communication, needs to
|
||||
1. All reduce in tp_attn_group on hidden_states
|
||||
2. Apply layer norm
|
||||
"""
|
||||
|
||||
def _inner():
|
||||
nonlocal hidden_states, residual
|
||||
@staticmethod
|
||||
def get_fn(
|
||||
hidden_states_input_mode: ScatterMode,
|
||||
residual_input_mode: ScatterMode,
|
||||
hidden_states_output_mode: ScatterMode,
|
||||
residual_output_mode: ScatterMode,
|
||||
context: _Context,
|
||||
):
|
||||
|
||||
if (
|
||||
context.is_same_group_size(
|
||||
@@ -332,10 +303,7 @@ def _communicate_with_all_reduce_and_layer_norm(
|
||||
and context.is_same_group_size(residual_input_mode, residual_output_mode)
|
||||
and context.attn_tp_size == 1
|
||||
):
|
||||
# TODO move these `if shape != 0` into LayerNorm itself
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states, residual = layernorm(hidden_states, residual)
|
||||
return hidden_states, residual
|
||||
return _CommunicateWithAllReduceAndLayerNormFn._simple
|
||||
|
||||
if (
|
||||
(hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
|
||||
@@ -343,21 +311,7 @@ def _communicate_with_all_reduce_and_layer_norm(
|
||||
and (hidden_states_output_mode == ScatterMode.FULL)
|
||||
and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
|
||||
):
|
||||
if context.local_attn_dp_size != 1:
|
||||
if context.attn_tp_rank == 0:
|
||||
hidden_states += residual
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer,
|
||||
hidden_states,
|
||||
)
|
||||
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
|
||||
dp_scatter(residual, hidden_states, forward_batch)
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states = layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
hidden_states, residual = layernorm(hidden_states, residual)
|
||||
return hidden_states, residual
|
||||
return _CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states
|
||||
|
||||
if (
|
||||
(hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
|
||||
@@ -367,85 +321,147 @@ def _communicate_with_all_reduce_and_layer_norm(
|
||||
and (hidden_states_output_mode == ScatterMode.SCATTERED)
|
||||
and (residual_output_mode == ScatterMode.SCATTERED)
|
||||
):
|
||||
tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
|
||||
hidden_states = tensor_list[context.attn_tp_rank]
|
||||
attn_tp_reduce_scatter(hidden_states, tensor_list)
|
||||
if residual_input_mode == ScatterMode.TP_ATTN_FULL:
|
||||
residual = residual.tensor_split(context.attn_tp_size)[
|
||||
context.attn_tp_rank
|
||||
]
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states, residual = layernorm(hidden_states, residual)
|
||||
return hidden_states, residual
|
||||
return partial(
|
||||
_CommunicateWithAllReduceAndLayerNormFn._scatter_hidden_states_and_residual,
|
||||
residual_input_mode=residual_input_mode,
|
||||
)
|
||||
|
||||
raise NotImplementedError(
|
||||
f"{hidden_states_input_mode=} {residual_input_mode=} {residual_output_mode=} {residual_output_mode=}"
|
||||
)
|
||||
|
||||
context.check_shapes(
|
||||
(hidden_states, residual), (hidden_states_input_mode, residual_input_mode)
|
||||
)
|
||||
return context.check_shapes(
|
||||
_inner(), (hidden_states_output_mode, residual_output_mode)
|
||||
)
|
||||
@staticmethod
|
||||
def _simple(
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
layernorm: torch.nn.Module,
|
||||
context: _Context,
|
||||
):
|
||||
# TODO move these `if shape != 0` into LayerNorm itself
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states, residual = layernorm(hidden_states, residual)
|
||||
return hidden_states, residual
|
||||
|
||||
@staticmethod
|
||||
def _gather_hidden_states(
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
layernorm: torch.nn.Module,
|
||||
context: _Context,
|
||||
):
|
||||
if context.local_attn_dp_size != 1:
|
||||
if context.attn_tp_rank == 0:
|
||||
hidden_states += residual
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer,
|
||||
hidden_states,
|
||||
)
|
||||
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
|
||||
dp_scatter(residual, hidden_states, forward_batch)
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states = layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
hidden_states, residual = layernorm(hidden_states, residual)
|
||||
return hidden_states, residual
|
||||
|
||||
@staticmethod
|
||||
def _scatter_hidden_states_and_residual(
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
layernorm: torch.nn.Module,
|
||||
context: _Context,
|
||||
*,
|
||||
residual_input_mode,
|
||||
):
|
||||
tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
|
||||
hidden_states = tensor_list[context.attn_tp_rank]
|
||||
attn_tp_reduce_scatter(hidden_states, tensor_list)
|
||||
if residual_input_mode == ScatterMode.TP_ATTN_FULL:
|
||||
residual = residual.tensor_split(context.attn_tp_size)[context.attn_tp_rank]
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states, residual = layernorm(hidden_states, residual)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
def _communicate_summable_tensor_pair(
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
hidden_states_input_mode: ScatterMode,
|
||||
residual_input_mode: ScatterMode,
|
||||
output_mode: ScatterMode,
|
||||
context: _Context,
|
||||
):
|
||||
"""It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed."""
|
||||
class _CommunicateSummableTensorPairFn:
|
||||
|
||||
def _inner():
|
||||
nonlocal hidden_states, residual
|
||||
@staticmethod
|
||||
def get_fn(
|
||||
hidden_states_input_mode: ScatterMode,
|
||||
residual_input_mode: ScatterMode,
|
||||
output_mode: ScatterMode,
|
||||
context: _Context,
|
||||
):
|
||||
"""It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed."""
|
||||
|
||||
if context.is_same_group_size(
|
||||
hidden_states_input_mode, output_mode
|
||||
) and context.is_same_group_size(residual_input_mode, output_mode):
|
||||
return hidden_states, residual
|
||||
return _CommunicateSummableTensorPairFn._trivial
|
||||
|
||||
if (
|
||||
(hidden_states_input_mode == ScatterMode.FULL)
|
||||
and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
|
||||
and (output_mode == ScatterMode.TP_ATTN_FULL)
|
||||
):
|
||||
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
|
||||
# important: forward batch.gathered_buffer is used both after scatter and after gather.
|
||||
# be careful about this!
|
||||
hidden_states, global_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
dp_scatter(hidden_states, global_hidden_states, forward_batch)
|
||||
return hidden_states, residual
|
||||
return _CommunicateSummableTensorPairFn._scatter_hidden_states
|
||||
|
||||
if (
|
||||
(hidden_states_input_mode == ScatterMode.SCATTERED)
|
||||
and (residual_input_mode == ScatterMode.SCATTERED)
|
||||
and (output_mode == ScatterMode.TP_ATTN_FULL)
|
||||
):
|
||||
hidden_states += residual
|
||||
residual = None
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
attn_tp_all_gather(
|
||||
list(hidden_states.tensor_split(context.attn_tp_size)),
|
||||
local_hidden_states,
|
||||
)
|
||||
return hidden_states, residual
|
||||
return _CommunicateSummableTensorPairFn._gather
|
||||
|
||||
raise NotImplementedError(
|
||||
f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
|
||||
)
|
||||
|
||||
context.check_shapes(
|
||||
(hidden_states, residual), (hidden_states_input_mode, residual_input_mode)
|
||||
)
|
||||
return context.check_shapes(_inner(), (output_mode, output_mode))
|
||||
@staticmethod
|
||||
def _trivial(
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
context: _Context,
|
||||
):
|
||||
return hidden_states, residual
|
||||
|
||||
@staticmethod
|
||||
def _scatter_hidden_states(
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
context: _Context,
|
||||
):
|
||||
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
|
||||
# important: forward batch.gathered_buffer is used both after scatter and after gather.
|
||||
# be careful about this!
|
||||
hidden_states, global_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
dp_scatter(hidden_states, global_hidden_states, forward_batch)
|
||||
return hidden_states, residual
|
||||
|
||||
@staticmethod
|
||||
def _gather(
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
context: _Context,
|
||||
):
|
||||
hidden_states += residual
|
||||
residual = None
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
attn_tp_all_gather(
|
||||
list(hidden_states.tensor_split(context.attn_tp_size)),
|
||||
local_hidden_states,
|
||||
)
|
||||
return hidden_states, residual
|
||||
|
||||
Reference in New Issue
Block a user