Support (1 <= dp < tp) in the dp attention in DeepEP (#4770)
Co-authored-by: Cheng Wan <cwan39@gatech.edu>
This commit is contained in:
@@ -39,6 +39,8 @@ from sglang.srt.layers.dp_attention import (
|
||||
get_attention_dp_size,
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
tp_all_gather,
|
||||
tp_reduce_scatter,
|
||||
)
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
@@ -278,7 +280,11 @@ class DeepseekV2MoE(nn.Module):
|
||||
topk_weights = torch.empty(
|
||||
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
if forward_mode is not None and not forward_mode.is_idle():
|
||||
if (
|
||||
forward_mode is not None
|
||||
and not forward_mode.is_idle()
|
||||
and hidden_states.shape[0] > 0
|
||||
):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
if self.n_shared_experts is not None:
|
||||
@@ -969,6 +975,14 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
is_nextn: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
|
||||
def is_sparse_layer(l: int):
    """Return True when layer index ``l`` satisfies the routed-expert layer test.

    All three conditions must hold, read from the enclosing ``config``:
    ``n_routed_experts`` is set, ``l`` is at or past
    ``first_k_dense_replace``, and ``l`` is a multiple of ``moe_layer_freq``.
    """
    # No routed experts configured: no layer passes the test.
    if config.n_routed_experts is None:
        return False
    # Indices below first_k_dense_replace never pass.
    if l < config.first_k_dense_replace:
        return False
    # Past that threshold, only every moe_layer_freq-th index passes.
    return l % config.moe_layer_freq == 0
|
||||
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
@@ -977,6 +991,8 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
|
||||
self.layer_id = layer_id
|
||||
self.dp_size = get_attention_dp_size()
|
||||
self.attn_tp_size = get_attention_tp_size()
|
||||
self.attn_tp_rank = get_attention_tp_rank()
|
||||
|
||||
if not global_server_args_dict["disable_mla"]:
|
||||
self.self_attn = DeepseekV2AttentionMLA(
|
||||
@@ -1019,16 +1035,13 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
|
||||
if is_nextn or (
|
||||
config.n_routed_experts is not None
|
||||
and layer_id >= config.first_k_dense_replace
|
||||
and layer_id % config.moe_layer_freq == 0
|
||||
):
|
||||
if is_nextn or is_sparse_layer(layer_id):
|
||||
self.mlp = DeepseekV2MoE(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.is_sparse = True
|
||||
else:
|
||||
self.mlp = DeepseekV2MLP(
|
||||
hidden_size=config.hidden_size,
|
||||
@@ -1037,6 +1050,14 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.is_sparse = False
|
||||
|
||||
self.input_is_scattered = (
|
||||
is_sparse_layer(layer_id - 1)
|
||||
and global_server_args_dict["enable_deepep_moe"]
|
||||
)
|
||||
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
|
||||
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
@@ -1049,6 +1070,23 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
|
||||
return self.forward_deepep(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
else:
|
||||
return self.forward_normal(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
|
||||
def forward_normal(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
|
||||
if hidden_states.shape[0] == 0:
|
||||
residual = hidden_states
|
||||
else:
|
||||
@@ -1065,29 +1103,35 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
if self.attn_tp_size != 1 and self.input_is_scattered:
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
tp_all_gather(
|
||||
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
|
||||
)
|
||||
residual, local_residual = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
residual,
|
||||
)
|
||||
tp_all_gather(
|
||||
list(residual.tensor_split(self.attn_tp_size)), local_residual
|
||||
)
|
||||
|
||||
# Gather
|
||||
if get_tensor_model_parallel_world_size() > 1:
|
||||
# all gather and all reduce
|
||||
if self.dp_size != 1:
|
||||
if global_server_args_dict["enable_deepep_moe"] and isinstance(
|
||||
self.mlp, DeepseekV2MoE
|
||||
):
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual
|
||||
)
|
||||
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
|
||||
return hidden_states, residual
|
||||
else:
|
||||
if get_attention_tp_rank() == 0:
|
||||
hidden_states += residual
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer,
|
||||
hidden_states,
|
||||
)
|
||||
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
|
||||
dp_scatter(residual, hidden_states, forward_batch)
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
if self.attn_tp_rank == 0:
|
||||
hidden_states += residual
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer,
|
||||
hidden_states,
|
||||
)
|
||||
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
|
||||
dp_scatter(residual, hidden_states, forward_batch)
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
@@ -1101,6 +1145,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
# Fully Connected
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
|
||||
# Scatter
|
||||
if self.dp_size != 1:
|
||||
# important: forward_batch.gathered_buffer is used both after scatter and after gather.
|
||||
@@ -1113,6 +1158,82 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
def forward_deepep(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    forward_batch: ForwardBatch,
    residual: Optional[torch.Tensor],
) -> torch.Tensor:
    """Run this decoder layer on the DeepEP path.

    Steps, in order: input layernorm, optional all-gather of the layer
    input across the attention-TP group, self attention, reduce-scatter
    plus post-attention layernorm, the MLP call, and (last layer only) an
    all-gather of the outputs.

    Args:
        positions: Forwarded unchanged to ``self.self_attn``.
        hidden_states: Layer input. When ``self.input_is_scattered`` is set
            it is treated as this rank's shard and is all-gathered before
            attention.
        forward_batch: Supplies ``gathered_buffer``, ``input_ids`` (its
            first dimension sizes the gathered views) and ``forward_mode``.
        residual: Residual stream from the previous layer, or ``None``; in
            that case the un-normalized input becomes the residual.

    Returns:
        A ``(hidden_states, residual)`` pair. NOTE(review): the annotation
        says ``torch.Tensor`` but a 2-tuple is returned; the annotation is
        left unchanged here to keep the signature stable.
    """
    # An empty local batch skips the layernorm; the empty tensor is reused
    # as the residual.
    if hidden_states.shape[0] == 0:
        residual = hidden_states
    elif residual is None:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
    else:
        hidden_states, residual = self.input_layernorm(hidden_states, residual)

    if self.attn_tp_size != 1 and self.input_is_scattered:
        # Rebuild the full input for attention from every rank's shard,
        # writing into a view of the shared gather buffer.
        hidden_states, local_hidden_states = (
            forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
            hidden_states,
        )
        tp_all_gather(
            list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
        )

    # Self Attention
    hidden_states = self.self_attn(
        positions=positions,
        hidden_states=hidden_states,
        forward_batch=forward_batch,
    )

    if self.attn_tp_size != 1:
        if self.input_is_scattered:
            # The residual is already a local shard here; only the
            # attention output is reduce-scattered.
            tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
            hidden_states = tensor_list[self.attn_tp_rank]
            tp_reduce_scatter(hidden_states, tensor_list)
            if hidden_states.shape[0] != 0:
                hidden_states, residual = self.post_attention_layernorm(
                    hidden_states, residual
                )
        else:
            # Add the residual on rank 0 only, so the reduction counts it
            # once; the scattered sum then becomes the new residual.
            if self.attn_tp_rank == 0:
                hidden_states += residual
            tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
            hidden_states = tensor_list[self.attn_tp_rank]
            tp_reduce_scatter(hidden_states, tensor_list)
            residual = hidden_states
            if hidden_states.shape[0] != 0:
                hidden_states = self.post_attention_layernorm(hidden_states)
    elif hidden_states.shape[0] != 0:
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual
        )

    # The MLP is called even for an empty local batch.
    hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)

    if self.is_last_layer and self.attn_tp_size != 1:
        # No later layer will re-gather, so return full-size tensors.
        hidden_states, local_hidden_states = (
            forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
            hidden_states,
        )
        tp_all_gather(
            list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
        )
        # The residual gets its own storage. Gathering it into
        # ``gathered_buffer`` as well would overwrite the hidden states
        # gathered just above and make both returned tensors alias the
        # same memory.
        local_residual = residual
        residual = local_residual.new_empty(hidden_states.shape)
        tp_all_gather(
            list(residual.tensor_split(self.attn_tp_size)), local_residual
        )

    return hidden_states, residual
|
||||
|
||||
|
||||
class DeepseekV2Model(nn.Module):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user