Use reduce scatter for DP (#8539)
This commit is contained in:
@@ -27,6 +27,7 @@ from sglang.srt.layers.dp_attention import (
|
|||||||
attn_tp_all_gather_into_tensor,
|
attn_tp_all_gather_into_tensor,
|
||||||
attn_tp_reduce_scatter_tensor,
|
attn_tp_reduce_scatter_tensor,
|
||||||
dp_gather_partial,
|
dp_gather_partial,
|
||||||
|
dp_reduce_scatter_tensor,
|
||||||
dp_scatter,
|
dp_scatter,
|
||||||
get_attention_dp_size,
|
get_attention_dp_size,
|
||||||
get_attention_tp_rank,
|
get_attention_tp_rank,
|
||||||
@@ -149,10 +150,13 @@ class LayerCommunicator:
|
|||||||
layer_scatter_modes: LayerScatterModes,
|
layer_scatter_modes: LayerScatterModes,
|
||||||
input_layernorm: torch.nn.Module,
|
input_layernorm: torch.nn.Module,
|
||||||
post_attention_layernorm: torch.nn.Module,
|
post_attention_layernorm: torch.nn.Module,
|
||||||
|
# Reduce-scatter requires the model code to skip its all-reduce after MoE/MLP, so enable it only for models that implement that skip. TODO: remove this flag once every model that uses LayerCommunicator supports it.
|
||||||
|
allow_reduce_scatter: bool = False,
|
||||||
):
|
):
|
||||||
self.layer_scatter_modes = layer_scatter_modes
|
self.layer_scatter_modes = layer_scatter_modes
|
||||||
self.input_layernorm = input_layernorm
|
self.input_layernorm = input_layernorm
|
||||||
self.post_attention_layernorm = post_attention_layernorm
|
self.post_attention_layernorm = post_attention_layernorm
|
||||||
|
self.allow_reduce_scatter = allow_reduce_scatter
|
||||||
|
|
||||||
self._context = CommunicateContext.init_new()
|
self._context = CommunicateContext.init_new()
|
||||||
self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
|
self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
|
||||||
@@ -239,6 +243,15 @@ class LayerCommunicator:
|
|||||||
residual=residual,
|
residual=residual,
|
||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
context=self._context,
|
context=self._context,
|
||||||
|
allow_reduce_scatter=self.allow_reduce_scatter,
|
||||||
|
)
|
||||||
|
|
||||||
|
def should_use_reduce_scatter(self, forward_batch: ForwardBatch):
    """Tell the caller whether reduce-scatter can be used for this batch.

    The result is truthy only when all three conditions hold:
      * this communicator was constructed with ``allow_reduce_scatter`` set,
      * the selected summable-tensor-pair function is
        ``CommunicateSummableTensorPairFn._scatter_hidden_states``,
      * ``forward_batch.dp_padding_mode.is_max_len()`` is truthy.
    """
    allowed = self.allow_reduce_scatter
    if not allowed:
        # Return the flag itself so the value matches a short-circuited
        # ``and`` chain exactly.
        return allowed

    pair_fn = self._communicate_summable_tensor_pair_fn
    if pair_fn is not CommunicateSummableTensorPairFn._scatter_hidden_states:
        return False

    return forward_batch.dp_padding_mode.is_max_len()
|
||||||
|
|
||||||
|
|
||||||
@@ -524,6 +537,7 @@ class CommunicateSummableTensorPairFn:
|
|||||||
residual: torch.Tensor,
|
residual: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
context: CommunicateContext,
|
context: CommunicateContext,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
@@ -533,15 +547,17 @@ class CommunicateSummableTensorPairFn:
|
|||||||
residual: torch.Tensor,
|
residual: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
context: CommunicateContext,
|
context: CommunicateContext,
|
||||||
|
allow_reduce_scatter: bool = False,
|
||||||
):
|
):
|
||||||
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
|
|
||||||
# important: forward batch.gathered_buffer is used both after scatter and after gather.
|
|
||||||
# be careful about this!
|
|
||||||
hidden_states, global_hidden_states = (
|
hidden_states, global_hidden_states = (
|
||||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||||
hidden_states,
|
hidden_states,
|
||||||
)
|
)
|
||||||
dp_scatter(hidden_states, global_hidden_states, forward_batch)
|
if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
|
||||||
|
# With max-len DP padding, the all-reduce after MLP/MoE is skipped and the reduction is performed here with a reduce-scatter instead.
|
||||||
|
dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
|
||||||
|
else:
|
||||||
|
dp_scatter(hidden_states, global_hidden_states, forward_batch)
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -550,6 +566,7 @@ class CommunicateSummableTensorPairFn:
|
|||||||
residual: torch.Tensor,
|
residual: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
context: CommunicateContext,
|
context: CommunicateContext,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
hidden_states += residual
|
hidden_states += residual
|
||||||
residual = None
|
residual = None
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import triton.language as tl
|
|||||||
|
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
GroupCoordinator,
|
GroupCoordinator,
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
get_tp_group,
|
get_tp_group,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
@@ -355,6 +356,17 @@ def dp_scatter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def dp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
    """Reduce-scatter ``input`` over the TP group and store the result in ``output``.

    When the tensor-parallel world size equals the attention DP size, a single
    reduce-scatter on the TP group writes straight into ``output``.

    Otherwise ``input`` is split into one chunk per TP rank, the reduce-scatter
    result is written into this rank's chunk, and that chunk is then
    all-gathered across the attention-TP group into ``output``.

    Args:
        output: Destination tensor, written in place.
        input: Source tensor. In the second case one of its chunks is
            overwritten, because ``tensor_split`` returns views of ``input``.
    """
    tp_world_size = get_tensor_model_parallel_world_size()

    if tp_world_size == get_attention_dp_size():
        get_tp_group().reduce_scatter_tensor(output, input)
        return

    # This rank's chunk is a view into `input`, so the reduce-scatter below
    # writes its result in place into that slice.
    chunks = input.tensor_split(tp_world_size)
    local_chunk = chunks[get_tensor_model_parallel_rank()]
    get_tp_group().reduce_scatter_tensor(local_chunk, input)
    get_attention_tp_group().all_gather_into_tensor(output, local_chunk)
|
||||||
|
|
||||||
|
|
||||||
def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
    """Run ``reduce_scatter_tensor`` on the attention-TP group.

    Args:
        output: Tensor handed to the group as the destination.
        input: Tensor handed to the group as the source.

    Returns:
        Whatever the group's ``reduce_scatter_tensor`` call returns.
    """
    attn_tp_group = get_attention_tp_group()
    return attn_tp_group.reduce_scatter_tensor(output, input)
|
||||||
|
|
||||||
|
|||||||
@@ -1277,7 +1277,7 @@ class RowParallelLinear(LinearBase):
|
|||||||
# It does not support additional parameters.
|
# It does not support additional parameters.
|
||||||
param.load_row_parallel_weight(loaded_weight)
|
param.load_row_parallel_weight(loaded_weight)
|
||||||
|
|
||||||
def forward(self, input_, can_fuse_mlp_allreduce=False):
|
def forward(self, input_, skip_all_reduce=False):
|
||||||
if self.input_is_parallel:
|
if self.input_is_parallel:
|
||||||
input_parallel = input_
|
input_parallel = input_
|
||||||
else:
|
else:
|
||||||
@@ -1294,7 +1294,7 @@ class RowParallelLinear(LinearBase):
|
|||||||
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
|
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
|
||||||
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
|
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
|
||||||
sm.tag(output_parallel)
|
sm.tag(output_parallel)
|
||||||
if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce:
|
if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
|
||||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||||
else:
|
else:
|
||||||
output = output_parallel
|
output = output_parallel
|
||||||
|
|||||||
@@ -628,8 +628,10 @@ class ForwardBatch:
|
|||||||
self.dp_padding_mode = dp_padding_mode
|
self.dp_padding_mode = dp_padding_mode
|
||||||
|
|
||||||
if dp_padding_mode.is_max_len():
|
if dp_padding_mode.is_max_len():
|
||||||
# when DP gather mode is all gather, we will use all_gather_into_tensor to gather hidden states,
|
# when DP gather mode is all gather, we will use
|
||||||
# where transferred tokens should be padded to the same length.
|
# all_gather_into_tensor to gather hidden states, where transferred
|
||||||
|
# tokens should be padded to the same length. We will also use
|
||||||
|
# reduce-scatter instead of all-reduce after MLP.
|
||||||
max_num_tokens = max(global_num_tokens)
|
max_num_tokens = max(global_num_tokens)
|
||||||
global_num_tokens = [max_num_tokens] * sync_group_size
|
global_num_tokens = [max_num_tokens] * sync_group_size
|
||||||
buffer_len = max_num_tokens * sync_group_size
|
buffer_len = max_num_tokens * sync_group_size
|
||||||
|
|||||||
@@ -208,13 +208,21 @@ class DeepseekV2MLP(nn.Module):
|
|||||||
)
|
)
|
||||||
self.act_fn = SiluAndMul()
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False):
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
forward_batch=None,
|
||||||
|
can_fuse_mlp_allreduce: bool = False,
|
||||||
|
use_reduce_scatter: bool = False,
|
||||||
|
):
|
||||||
if (self.tp_size == 1) and x.shape[0] == 0:
|
if (self.tp_size == 1) and x.shape[0] == 0:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
gate_up, _ = self.gate_up_proj(x)
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
x = self.act_fn(gate_up)
|
x = self.act_fn(gate_up)
|
||||||
x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce)
|
x, _ = self.down_proj(
|
||||||
|
x, skip_all_reduce=can_fuse_mlp_allreduce or use_reduce_scatter
|
||||||
|
)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@@ -441,6 +449,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
forward_batch: Optional[ForwardBatch] = None,
|
forward_batch: Optional[ForwardBatch] = None,
|
||||||
can_fuse_mlp_allreduce: bool = False,
|
can_fuse_mlp_allreduce: bool = False,
|
||||||
|
use_reduce_scatter: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if not self._enable_deepep_moe:
|
if not self._enable_deepep_moe:
|
||||||
DUAL_STREAM_TOKEN_THRESHOLD = 1024
|
DUAL_STREAM_TOKEN_THRESHOLD = 1024
|
||||||
@@ -450,15 +459,20 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
|
and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
|
||||||
):
|
):
|
||||||
return self.forward_normal_dual_stream(
|
return self.forward_normal_dual_stream(
|
||||||
hidden_states, can_fuse_mlp_allreduce
|
hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self.forward_normal(hidden_states, can_fuse_mlp_allreduce)
|
return self.forward_normal(
|
||||||
|
hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return self.forward_deepep(hidden_states, forward_batch)
|
return self.forward_deepep(hidden_states, forward_batch)
|
||||||
|
|
||||||
def forward_normal_dual_stream(
|
def forward_normal_dual_stream(
|
||||||
self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
can_fuse_mlp_allreduce: bool = False,
|
||||||
|
use_reduce_scatter: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
current_stream = torch.cuda.current_stream()
|
current_stream = torch.cuda.current_stream()
|
||||||
@@ -486,12 +500,15 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
|
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
|
||||||
final_hidden_states = final_hidden_states_out
|
final_hidden_states = final_hidden_states_out
|
||||||
sm.tag(final_hidden_states)
|
sm.tag(final_hidden_states)
|
||||||
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
|
if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
|
|
||||||
def forward_normal(
|
def forward_normal(
|
||||||
self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
can_fuse_mlp_allreduce: bool = False,
|
||||||
|
use_reduce_scatter: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if hasattr(self, "shared_experts") and use_intel_amx_backend(
|
if hasattr(self, "shared_experts") and use_intel_amx_backend(
|
||||||
self.shared_experts.gate_up_proj
|
self.shared_experts.gate_up_proj
|
||||||
@@ -520,7 +537,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
|
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
|
||||||
final_hidden_states = final_hidden_states_out
|
final_hidden_states = final_hidden_states_out
|
||||||
sm.tag(final_hidden_states)
|
sm.tag(final_hidden_states)
|
||||||
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
|
if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
|
|
||||||
@@ -1822,6 +1839,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
layer_scatter_modes=self.layer_scatter_modes,
|
layer_scatter_modes=self.layer_scatter_modes,
|
||||||
input_layernorm=self.input_layernorm,
|
input_layernorm=self.input_layernorm,
|
||||||
post_attention_layernorm=self.post_attention_layernorm,
|
post_attention_layernorm=self.post_attention_layernorm,
|
||||||
|
allow_reduce_scatter=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
|
def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
|
||||||
@@ -1884,7 +1902,13 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
and not self.is_nextn
|
and not self.is_nextn
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = self.mlp(hidden_states, forward_batch, can_fuse_mlp_allreduce)
|
# For DP with padding, reduce scatter can be used instead of all-reduce.
|
||||||
|
use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
|
||||||
|
forward_batch
|
||||||
|
)
|
||||||
|
hidden_states = self.mlp(
|
||||||
|
hidden_states, forward_batch, can_fuse_mlp_allreduce, use_reduce_scatter
|
||||||
|
)
|
||||||
|
|
||||||
if can_fuse_mlp_allreduce:
|
if can_fuse_mlp_allreduce:
|
||||||
hidden_states._sglang_needs_allreduce_fusion = True
|
hidden_states._sglang_needs_allreduce_fusion = True
|
||||||
|
|||||||
@@ -160,7 +160,7 @@ class Glm4MoeMLP(nn.Module):
|
|||||||
|
|
||||||
gate_up, _ = self.gate_up_proj(x)
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
x = self.act_fn(gate_up)
|
x = self.act_fn(gate_up)
|
||||||
x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce)
|
x, _ = self.down_proj(x, skip_all_reduce=can_fuse_mlp_allreduce)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user