Use reduce scatter for DP (#8539)
@@ -27,6 +27,7 @@ from sglang.srt.layers.dp_attention import (
     attn_tp_all_gather_into_tensor,
     attn_tp_reduce_scatter_tensor,
     dp_gather_partial,
+    dp_reduce_scatter_tensor,
     dp_scatter,
     get_attention_dp_size,
     get_attention_tp_rank,
@@ -149,10 +150,13 @@ class LayerCommunicator:
         layer_scatter_modes: LayerScatterModes,
         input_layernorm: torch.nn.Module,
         post_attention_layernorm: torch.nn.Module,
+        # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable it for models that have that implemented. Remove this flag once that is done for all models that use LayerCommunicator.
+        allow_reduce_scatter: bool = False,
     ):
         self.layer_scatter_modes = layer_scatter_modes
         self.input_layernorm = input_layernorm
         self.post_attention_layernorm = post_attention_layernorm
+        self.allow_reduce_scatter = allow_reduce_scatter

         self._context = CommunicateContext.init_new()
         self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
@@ -239,6 +243,15 @@ class LayerCommunicator:
             residual=residual,
             forward_batch=forward_batch,
             context=self._context,
+            allow_reduce_scatter=self.allow_reduce_scatter,
         )

+    def should_use_reduce_scatter(self, forward_batch: ForwardBatch):
+        return (
+            self.allow_reduce_scatter
+            and self._communicate_summable_tensor_pair_fn
+            is CommunicateSummableTensorPairFn._scatter_hidden_states
+            and forward_batch.dp_padding_mode.is_max_len()
+        )
+
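The gating above relies on the identity that an all-reduce followed by a per-rank scatter equals a single reduce-scatter, once every rank's shard is padded to the same max length. A minimal single-process sketch of that identity (illustrative only, not SGLang code; rank tensors are simulated by entries of a list and all shapes are hypothetical):

import torch

world_size, tokens_per_rank, hidden = 4, 3, 8
# per_rank_inputs[r] is the partial result rank r would hold before all-reduce
per_rank_inputs = [torch.randn(world_size * tokens_per_rank, hidden) for _ in range(world_size)]

# Path A: all-reduce, then each rank scatters out its own token slice
reduced = torch.stack(per_rank_inputs).sum(dim=0)
path_a = [reduced[r * tokens_per_rank : (r + 1) * tokens_per_rank] for r in range(world_size)]

# Path B: reduce-scatter directly produces each rank's summed slice
path_b = [
    sum(inp[r * tokens_per_rank : (r + 1) * tokens_per_rank] for inp in per_rank_inputs)
    for r in range(world_size)
]

for a, b in zip(path_a, path_b):
    assert torch.allclose(a, b)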
@@ -524,6 +537,7 @@ class CommunicateSummableTensorPairFn:
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         context: CommunicateContext,
+        **kwargs,
     ):
         return hidden_states, residual
@@ -533,15 +547,17 @@ class CommunicateSummableTensorPairFn:
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         context: CommunicateContext,
+        allow_reduce_scatter: bool = False,
     ):
         # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
         # important: forward_batch.gathered_buffer is used both after scatter and after gather.
         # be careful about this!
         hidden_states, global_hidden_states = (
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
-        dp_scatter(hidden_states, global_hidden_states, forward_batch)
+        if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
+            # When using padding, all_reduce is skipped after MLP and MoE, and reduce-scatter is used here instead.
+            dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
+        else:
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
         return hidden_states, residual

     @staticmethod
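Both branches above fill a local slice of `gathered_buffer` from a padded global tensor; they differ only in whether the global tensor has already been reduced. A rough single-process sketch of the shape contract (hypothetical shapes and rank, not SGLang code):

import torch

dp_size, max_tokens, hidden = 4, 7, 16
rank = 2  # hypothetical DP rank

global_hidden_states = torch.randn(dp_size * max_tokens, hidden)  # padded global batch
local_out = torch.empty(max_tokens, hidden)  # stands in for the gathered_buffer prefix slice

# dp_scatter-like branch: the global tensor was already all-reduced after the
# MLP, so the rank just copies out its own token slice.
local_out.copy_(global_hidden_states[rank * max_tokens : (rank + 1) * max_tokens])

# dp_reduce_scatter_tensor instead sums each rank's un-reduced contribution
# into that same slice while scattering, making the earlier all-reduce unnecessary.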
@@ -550,6 +566,7 @@ class CommunicateSummableTensorPairFn:
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         context: CommunicateContext,
+        **kwargs,
     ):
         hidden_states += residual
         residual = None
@@ -12,6 +12,7 @@ import triton.language as tl

 from sglang.srt.distributed import (
     GroupCoordinator,
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tp_group,
     tensor_model_parallel_all_reduce,
@@ -355,6 +356,17 @@ def dp_scatter(
     )


+def dp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
+    if get_tensor_model_parallel_world_size() == get_attention_dp_size():
+        get_tp_group().reduce_scatter_tensor(output, input)
+    else:
+        scattered_local_tokens = input.tensor_split(
+            get_tensor_model_parallel_world_size()
+        )[get_tensor_model_parallel_rank()]
+        get_tp_group().reduce_scatter_tensor(scattered_local_tokens, input)
+        get_attention_tp_group().all_gather_into_tensor(output, scattered_local_tokens)
+
+
 def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
     return get_attention_tp_group().reduce_scatter_tensor(output, input)
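A single-process simulation of the fallback branch above (illustrative, not SGLang code; it assumes consecutive TP ranks form each attention-TP subgroup, with a hypothetical tp_size=4 and attention dp_size=2): each TP rank first keeps one 1/tp_size reduce-scatter shard, then the subgroup all-gather reassembles the 1/dp_size slice owned by that rank's DP group.

import torch

tp_size, dp_size = 4, 2
attn_tp_size = tp_size // dp_size  # ranks per attention-TP subgroup
tokens, hidden = 8, 4              # global token count, padded to max len

# per-TP-rank partial results that would normally be all-reduced
inputs = [torch.randn(tokens, hidden) for _ in range(tp_size)]
reduced = torch.stack(inputs).sum(dim=0)
shards = reduced.tensor_split(tp_size)  # what reduce_scatter_tensor leaves on each rank

for rank in range(tp_size):
    dp_group = rank // attn_tp_size  # this rank's attention-DP group (assumed mapping)
    # all_gather_into_tensor over the subgroup glues the 1/tp_size shards
    # back into the DP group's 1/dp_size slice
    subgroup = range(dp_group * attn_tp_size, (dp_group + 1) * attn_tp_size)
    output = torch.cat([shards[r] for r in subgroup])
    assert torch.equal(output, reduced.tensor_split(dp_size)[dp_group])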
@@ -1277,7 +1277,7 @@ class RowParallelLinear(LinearBase):
         # It does not support additional parameters.
         param.load_row_parallel_weight(loaded_weight)

-    def forward(self, input_, can_fuse_mlp_allreduce=False):
+    def forward(self, input_, skip_all_reduce=False):
         if self.input_is_parallel:
             input_parallel = input_
         else:
@@ -1294,7 +1294,7 @@ class RowParallelLinear(LinearBase):
         with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
             output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
             sm.tag(output_parallel)
-        if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce:
+        if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output = output_parallel
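With skip_all_reduce=True the layer returns output_parallel, i.e. each TP rank's partial matmul result, and the caller takes responsibility for summing the partials with a later collective (here, the reduce-scatter). A toy sketch of what is left unreduced (hypothetical shapes, not SGLang code):

import torch

tp_size, tokens, hidden_in, hidden_out = 2, 4, 8, 6
# row-parallel: input features and weight rows are sharded across TP ranks
x_shards = [torch.randn(tokens, hidden_in // tp_size) for _ in range(tp_size)]
w_shards = [torch.randn(hidden_in // tp_size, hidden_out) for _ in range(tp_size)]

# each rank's output_parallel; skip_all_reduce=True would return this as-is
partials = [x @ w for x, w in zip(x_shards, w_shards)]

# the true output only exists once some collective sums the partials,
# e.g. tensor_model_parallel_all_reduce or the later reduce-scatter
full = sum(partials)
assert full.shape == (tokens, hidden_out)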
@@ -628,8 +628,10 @@ class ForwardBatch:
         self.dp_padding_mode = dp_padding_mode

         if dp_padding_mode.is_max_len():
-            # when DP gather mode is all gather, we will use all_gather_into_tensor to gather hidden states,
-            # where transferred tokens should be padded to the same length.
+            # when DP gather mode is all gather, we will use
+            # all_gather_into_tensor to gather hidden states, where transferred
+            # tokens should be padded to the same length. We will also use
+            # reduce-scatter instead of all-reduce after MLP.
             max_num_tokens = max(global_num_tokens)
             global_num_tokens = [max_num_tokens] * sync_group_size
             buffer_len = max_num_tokens * sync_group_size
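Worked numbers for the padding above (hypothetical token counts, not SGLang code): padding every rank to the max makes the gathered buffer evenly divisible, which is what both all_gather_into_tensor and the new reduce-scatter path require.

global_num_tokens = [5, 3, 7, 2]  # hypothetical per-DP-rank token counts
sync_group_size = len(global_num_tokens)

max_num_tokens = max(global_num_tokens)                 # 7
global_num_tokens = [max_num_tokens] * sync_group_size  # [7, 7, 7, 7]
buffer_len = max_num_tokens * sync_group_size           # 28

# 28 tokens split into 4 equal 7-token shards; without padding the shards
# would be ragged (5/3/7/2) and reduce_scatter_tensor could not split evenly
assert buffer_len % sync_group_size == 0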
@@ -208,13 +208,21 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False):
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
+    ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x

         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce)
+        x, _ = self.down_proj(
+            x, skip_all_reduce=can_fuse_mlp_allreduce or use_reduce_scatter
+        )
         return x

@@ -441,6 +449,7 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: Optional[ForwardBatch] = None,
         can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -450,15 +459,20 @@ class DeepseekV2MoE(nn.Module):
                 and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
             ):
                 return self.forward_normal_dual_stream(
-                    hidden_states, can_fuse_mlp_allreduce
+                    hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
                 )
             else:
-                return self.forward_normal(hidden_states, can_fuse_mlp_allreduce)
+                return self.forward_normal(
+                    hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
+                )
         else:
             return self.forward_deepep(hidden_states, forward_batch)

     def forward_normal_dual_stream(
-        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+        self,
+        hidden_states: torch.Tensor,
+        can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:

         current_stream = torch.cuda.current_stream()
@@ -486,12 +500,15 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+        if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

     def forward_normal(
-        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+        self,
+        hidden_states: torch.Tensor,
+        can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
@@ -520,7 +537,7 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+        if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
@@ -1822,6 +1839,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )

     def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
@@ -1884,7 +1902,13 @@ class DeepseekV2DecoderLayer(nn.Module):
             and not self.is_nextn
         )

-        hidden_states = self.mlp(hidden_states, forward_batch, can_fuse_mlp_allreduce)
+        # For DP with padding, reduce-scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+        hidden_states = self.mlp(
+            hidden_states, forward_batch, can_fuse_mlp_allreduce, use_reduce_scatter
+        )

         if can_fuse_mlp_allreduce:
             hidden_states._sglang_needs_allreduce_fusion = True
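Putting the DeepseekV2 wiring together, the per-layer flow is roughly the sketch below (a condensed illustration, not the actual SGLang code; `layer` stands in for a DeepseekV2DecoderLayer):

def decoder_layer_mlp(layer, hidden_states, forward_batch, can_fuse_mlp_allreduce):
    # 1. decide once per layer whether the post-MLP all-reduce can be replaced
    use_reduce_scatter = layer.layer_communicator.should_use_reduce_scatter(forward_batch)
    # 2. the MLP/MoE skips its internal all-reduce when either fusion will handle it
    hidden_states = layer.mlp(
        hidden_states, forward_batch, can_fuse_mlp_allreduce, use_reduce_scatter
    )
    # 3. LayerCommunicator later calls dp_reduce_scatter_tensor on the
    #    un-reduced tensor (or dp_scatter on an already-reduced one)
    return hidden_states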
@@ -160,7 +160,7 @@ class Glm4MoeMLP(nn.Module):

         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce)
+        x, _ = self.down_proj(x, skip_all_reduce=can_fuse_mlp_allreduce)
         return x