Use reduce scatter for DP (#8539)

Trevor Morris
2025-08-06 16:21:26 -07:00
committed by GitHub
parent 92cc32d9fc
commit c0e84297c2
6 changed files with 73 additions and 18 deletions


@@ -27,6 +27,7 @@ from sglang.srt.layers.dp_attention import (
     attn_tp_all_gather_into_tensor,
     attn_tp_reduce_scatter_tensor,
     dp_gather_partial,
+    dp_reduce_scatter_tensor,
     dp_scatter,
     get_attention_dp_size,
     get_attention_tp_rank,
@@ -149,10 +150,13 @@ class LayerCommunicator:
         layer_scatter_modes: LayerScatterModes,
         input_layernorm: torch.nn.Module,
         post_attention_layernorm: torch.nn.Module,
+        # Reduce-scatter requires the model code to skip the all-reduce after MoE/MLP, so enable this only for models that implement that. Remove this flag once all models that use LayerCommunicator do.
+        allow_reduce_scatter: bool = False,
     ):
         self.layer_scatter_modes = layer_scatter_modes
         self.input_layernorm = input_layernorm
         self.post_attention_layernorm = post_attention_layernorm
+        self.allow_reduce_scatter = allow_reduce_scatter
 
         self._context = CommunicateContext.init_new()
         self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
@@ -239,6 +243,15 @@
             residual=residual,
             forward_batch=forward_batch,
             context=self._context,
+            allow_reduce_scatter=self.allow_reduce_scatter,
         )
+
+    def should_use_reduce_scatter(self, forward_batch: ForwardBatch):
+        return (
+            self.allow_reduce_scatter
+            and self._communicate_summable_tensor_pair_fn
+            is CommunicateSummableTensorPairFn._scatter_hidden_states
+            and forward_batch.dp_padding_mode.is_max_len()
+        )
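
To adopt the new path, a model constructs its LayerCommunicator with allow_reduce_scatter=True and then consults should_use_reduce_scatter to decide, per batch, whether its MLP/MoE may skip the usual all-reduce. A minimal sketch of that opt-in pattern follows; the DecoderLayer wrapper, the skip_all_reduce keyword, and the import path are illustrative assumptions rather than code from this PR, and postprocess_layer is assumed to be the postprocessing entry point shown in the hunk above.

import torch

from sglang.srt.layers.communicator import LayerCommunicator  # path assumed

class DecoderLayer(torch.nn.Module):
    """Hypothetical layer showing the opt-in; names are illustrative."""

    def __init__(self, mlp, input_layernorm, post_attention_layernorm, layer_scatter_modes):
        super().__init__()
        self.mlp = mlp
        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=layer_scatter_modes,
            input_layernorm=input_layernorm,
            post_attention_layernorm=post_attention_layernorm,
            # This model's MLP knows how to skip its all-reduce, so opt in.
            allow_reduce_scatter=True,
        )

    def forward(self, hidden_states, residual, forward_batch):
        # Skip the MLP-internal all-reduce whenever postprocessing will
        # reduce-scatter the padded gathered buffer instead.
        use_rs = self.layer_communicator.should_use_reduce_scatter(forward_batch)
        hidden_states = self.mlp(hidden_states, skip_all_reduce=use_rs)
        return self.layer_communicator.postprocess_layer(
            hidden_states, residual, forward_batch
        )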
@@ -524,6 +537,7 @@ class CommunicateSummableTensorPairFn:
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         context: CommunicateContext,
+        **kwargs,
     ):
         return hidden_states, residual
@@ -533,15 +547,17 @@
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         context: CommunicateContext,
+        allow_reduce_scatter: bool = False,
     ):
-        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
-        # important: forward batch.gathered_buffer is used both after scatter and after gather.
-        # be careful about this!
         hidden_states, global_hidden_states = (
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
-        dp_scatter(hidden_states, global_hidden_states, forward_batch)
+        if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
+            # With max-len padding, the all-reduce after MLP/MoE is skipped and a reduce-scatter is performed here instead.
+            dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
+        else:
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
         return hidden_states, residual
 
     @staticmethod
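
Why the substitution is only valid under max-len padding: dp_scatter slices each rank's tokens out of an already-reduced global buffer, whereas dp_reduce_scatter_tensor fuses the reduction and the slicing into one collective, summing the per-rank partial outputs and delivering each rank a fixed-size slice. Those slices line up with the per-rank token ranges only when every rank was padded to the same maximum length, hence the dp_padding_mode.is_max_len() guard. A single-process sketch of the equivalence, with plain tensors standing in for the collectives and illustrative shapes:

import torch

world_size, max_len, hidden = 4, 8, 16

# Partial MLP outputs over the full padded global buffer, one per rank
# (what each rank holds when its all-reduce is skipped).
partials = [torch.randn(world_size * max_len, hidden) for _ in range(world_size)]

# Path 1: all-reduce then scatter -- sum everything, take each rank's slice.
reduced = torch.stack(partials).sum(dim=0)
for rank in range(world_size):
    via_allreduce_scatter = reduced[rank * max_len : (rank + 1) * max_len]
    # Path 2: reduce-scatter -- each rank receives only its own slice, summed.
    via_reduce_scatter = torch.stack(
        [p[rank * max_len : (rank + 1) * max_len] for p in partials]
    ).sum(dim=0)
    torch.testing.assert_close(via_allreduce_scatter, via_reduce_scatter)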
@@ -550,6 +566,7 @@
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         context: CommunicateContext,
+        **kwargs,
     ):
         hidden_states += residual
         residual = None
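
The **kwargs added to the other variants keeps every CommunicateSummableTensorPairFn implementation call-compatible: postprocessing always passes allow_reduce_scatter, and only _scatter_hidden_states consumes it. A compact sketch of that uniform-signature dispatch pattern, with simplified hypothetical variants:

import torch

# Simplified, hypothetical stand-ins: every variant accepts the same
# call signature, so the dispatcher needs no per-variant argument logic.
def _keep(hidden_states, residual, **kwargs):
    return hidden_states, residual

def _sum(hidden_states, residual, **kwargs):
    return hidden_states + residual, None

def _scatter(hidden_states, residual, allow_reduce_scatter=False, **kwargs):
    # Only this variant actually reads the flag.
    return hidden_states[: residual.shape[0]], residual

# The chosen function is called identically regardless of which variant it is.
fn = _sum
hidden_states, residual = fn(
    torch.ones(4, 8), torch.ones(4, 8), allow_reduce_scatter=True
)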