Use reduce scatter for DP (#8539)
This commit is contained in:
@@ -27,6 +27,7 @@ from sglang.srt.layers.dp_attention import (
|
||||
attn_tp_all_gather_into_tensor,
|
||||
attn_tp_reduce_scatter_tensor,
|
||||
dp_gather_partial,
|
||||
dp_reduce_scatter_tensor,
|
||||
dp_scatter,
|
||||
get_attention_dp_size,
|
||||
get_attention_tp_rank,
|
||||
@@ -149,10 +150,13 @@ class LayerCommunicator:
|
||||
layer_scatter_modes: LayerScatterModes,
|
||||
input_layernorm: torch.nn.Module,
|
||||
post_attention_layernorm: torch.nn.Module,
|
||||
# Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
|
||||
allow_reduce_scatter: bool = False,
|
||||
):
|
||||
self.layer_scatter_modes = layer_scatter_modes
|
||||
self.input_layernorm = input_layernorm
|
||||
self.post_attention_layernorm = post_attention_layernorm
|
||||
self.allow_reduce_scatter = allow_reduce_scatter
|
||||
|
||||
self._context = CommunicateContext.init_new()
|
||||
self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
|
||||
@@ -239,6 +243,15 @@ class LayerCommunicator:
|
||||
residual=residual,
|
||||
forward_batch=forward_batch,
|
||||
context=self._context,
|
||||
allow_reduce_scatter=self.allow_reduce_scatter,
|
||||
)
|
||||
|
||||
def should_use_reduce_scatter(self, forward_batch: ForwardBatch):
    """Tell whether the reduce-scatter path applies to ``forward_batch``.

    The answer is truthy only when all of the following hold, checked in
    this order (later checks are skipped once one fails):

    1. ``self.allow_reduce_scatter`` is truthy.
    2. The summable-tensor-pair function selected for this communicator is
       ``CommunicateSummableTensorPairFn._scatter_hidden_states``.
    3. ``forward_batch.dp_padding_mode.is_max_len()`` is truthy.

    Args:
        forward_batch: Batch whose ``dp_padding_mode`` is inspected.

    Returns:
        The first failing condition's value, otherwise the result of
        ``forward_batch.dp_padding_mode.is_max_len()``.
    """
    allowed = self.allow_reduce_scatter
    if not allowed:
        # Feature flag is off for this communicator; nothing else is read.
        return allowed

    selected_fn = self._communicate_summable_tensor_pair_fn
    if selected_fn is not CommunicateSummableTensorPairFn._scatter_hidden_states:
        # Only the scatter variant has a reduce-scatter counterpart.
        return False

    return forward_batch.dp_padding_mode.is_max_len()
|
||||
|
||||
|
||||
@@ -524,6 +537,7 @@ class CommunicateSummableTensorPairFn:
|
||||
residual: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
context: CommunicateContext,
|
||||
**kwargs,
|
||||
):
|
||||
return hidden_states, residual
|
||||
|
||||
@@ -533,15 +547,17 @@ class CommunicateSummableTensorPairFn:
|
||||
residual: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
context: CommunicateContext,
|
||||
allow_reduce_scatter: bool = False,
|
||||
):
|
||||
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
|
||||
# important: forward_batch.gathered_buffer is used both after scatter and after gather.
|
||||
# be careful about this!
|
||||
hidden_states, global_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
dp_scatter(hidden_states, global_hidden_states, forward_batch)
|
||||
if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
|
||||
# When using padding, all_reduce is skipped after MLP and MOE and reduce scatter is used here instead.
|
||||
dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
|
||||
else:
|
||||
dp_scatter(hidden_states, global_hidden_states, forward_batch)
|
||||
return hidden_states, residual
|
||||
|
||||
@staticmethod
|
||||
@@ -550,6 +566,7 @@ class CommunicateSummableTensorPairFn:
|
||||
residual: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
context: CommunicateContext,
|
||||
**kwargs,
|
||||
):
|
||||
hidden_states += residual
|
||||
residual = None
|
||||
|
||||
Reference in New Issue
Block a user