Support (1 <= dp < tp) in the dp attention in DeepEP (#4770)

Co-authored-by: Cheng Wan <cwan39@gatech.edu>
This commit is contained in:
tarinkk
2025-03-27 20:09:35 -04:00
committed by GitHub
parent 98a2cfa9b2
commit 7f19e083c1
10 changed files with 238 additions and 47 deletions

View File

@@ -39,6 +39,8 @@ from sglang.srt.layers.dp_attention import (
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
tp_all_gather,
tp_reduce_scatter,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
@@ -278,7 +280,11 @@ class DeepseekV2MoE(nn.Module):
topk_weights = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
if forward_mode is not None and not forward_mode.is_idle():
if (
forward_mode is not None
and not forward_mode.is_idle()
and hidden_states.shape[0] > 0
):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
if self.n_shared_experts is not None:
@@ -969,6 +975,14 @@ class DeepseekV2DecoderLayer(nn.Module):
is_nextn: bool = False,
prefix: str = "",
) -> None:
def is_sparse_layer(l: int):
return (
config.n_routed_experts is not None
and l >= config.first_k_dense_replace
and l % config.moe_layer_freq == 0
)
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
@@ -977,6 +991,8 @@ class DeepseekV2DecoderLayer(nn.Module):
self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
self.layer_id = layer_id
self.dp_size = get_attention_dp_size()
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
if not global_server_args_dict["disable_mla"]:
self.self_attn = DeepseekV2AttentionMLA(
@@ -1019,16 +1035,13 @@ class DeepseekV2DecoderLayer(nn.Module):
prefix=add_prefix("self_attn", prefix),
)
if is_nextn or (
config.n_routed_experts is not None
and layer_id >= config.first_k_dense_replace
and layer_id % config.moe_layer_freq == 0
):
if is_nextn or is_sparse_layer(layer_id):
self.mlp = DeepseekV2MoE(
config=config,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
self.is_sparse = True
else:
self.mlp = DeepseekV2MLP(
hidden_size=config.hidden_size,
@@ -1037,6 +1050,14 @@ class DeepseekV2DecoderLayer(nn.Module):
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
self.is_sparse = False
self.input_is_scattered = (
is_sparse_layer(layer_id - 1)
and global_server_args_dict["enable_deepep_moe"]
)
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
@@ -1049,6 +1070,23 @@ class DeepseekV2DecoderLayer(nn.Module):
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
return self.forward_deepep(
positions, hidden_states, forward_batch, residual
)
else:
return self.forward_normal(
positions, hidden_states, forward_batch, residual
)
def forward_normal(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
if hidden_states.shape[0] == 0:
residual = hidden_states
else:
@@ -1065,29 +1103,35 @@ class DeepseekV2DecoderLayer(nn.Module):
forward_batch=forward_batch,
)
if self.attn_tp_size != 1 and self.input_is_scattered:
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)
residual, local_residual = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
residual,
)
tp_all_gather(
list(residual.tensor_split(self.attn_tp_size)), local_residual
)
# Gather
if get_tensor_model_parallel_world_size() > 1:
# all gather and all reduce
if self.dp_size != 1:
if global_server_args_dict["enable_deepep_moe"] and isinstance(
self.mlp, DeepseekV2MoE
):
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
return hidden_states, residual
else:
if get_attention_tp_rank() == 0:
hidden_states += residual
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer,
hidden_states,
)
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
dp_scatter(residual, hidden_states, forward_batch)
hidden_states = self.post_attention_layernorm(hidden_states)
if self.attn_tp_rank == 0:
hidden_states += residual
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer,
hidden_states,
)
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
dp_scatter(residual, hidden_states, forward_batch)
hidden_states = self.post_attention_layernorm(hidden_states)
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
hidden_states, residual = self.post_attention_layernorm(
@@ -1101,6 +1145,7 @@ class DeepseekV2DecoderLayer(nn.Module):
# Fully Connected
hidden_states = self.mlp(hidden_states)
# TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
# Scatter
if self.dp_size != 1:
# important: forward batch.gathered_buffer is used both after scatter and after gather.
@@ -1113,6 +1158,82 @@ class DeepseekV2DecoderLayer(nn.Module):
return hidden_states, residual
def forward_deepep(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
if hidden_states.shape[0] == 0:
residual = hidden_states
else:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
if self.attn_tp_size != 1 and self.input_is_scattered:
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)
# Self Attention
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
if self.attn_tp_size != 1:
if self.input_is_scattered:
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
hidden_states = tensor_list[self.attn_tp_rank]
tp_reduce_scatter(hidden_states, tensor_list)
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
else:
if self.attn_tp_rank == 0:
hidden_states += residual
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
hidden_states = tensor_list[self.attn_tp_rank]
tp_reduce_scatter(hidden_states, tensor_list)
residual = hidden_states
if hidden_states.shape[0] != 0:
hidden_states = self.post_attention_layernorm(hidden_states)
else:
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
if self.is_last_layer and self.attn_tp_size != 1:
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)
residual, local_residual = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
residual,
)
tp_all_gather(
list(residual.tensor_split(self.attn_tp_size)), local_residual
)
return hidden_states, residual
class DeepseekV2Model(nn.Module):