[main][prefill optimization] Optimize parallel strategies to reduce communication overhead (#2198)

### What this PR does / why we need it?
1.Shared Expert Sharding Strategy Update: Switched from TP-aligned to
pure DP for shared experts, enabling more efficient execution.
2.O_Proj AllReduce → ReduceScatter: Reduced communication overhead by
using ReduceScatter, made possible by pure DP sharding.
3.AllGather Postponed: Delayed to after QKV down projection to reduce
synchronization impact during prefill.

### How was this patch tested?
Adding ut case in `tests/ut/attention/test_mla_v1.py`

#### How to run

use parameter `--additional_config='{"enable_shared_expert_dp": true}'`

##### a.How to run eager mode

eg:
python -m vllm.entrypoints.openai.api_server --model=/model_path
--trust-remote-code -tp 8 -dp 2 --enable_expert_parallel --port 8002
--max-model-len 5120 --max-num-batched-tokens 16384 --enforce-eager
--disable-log-requests
--additional_config='{"ascend_scheduler_config":{"enabled":true},"enable_shared_expert_dp":
true,"chunked_prefill_for_mla":true}'

##### b.How to run graph mode

eg:
python -m vllm.entrypoints.openai.api_server --model=/model_path
--trust-remote-code -tp 8 -dp 2 --enable_expert_parallel --port 8002
--max-model-len 5120 --max-num-batched-tokens 16384
--disable-log-requests
--additional_config='{"ascend_scheduler_config":{"enabled":true},"enable_shared_expert_dp":
true,"chunked_prefill_for_mla":true,"torchair_graph_config":{"enabled":true}}'


- vLLM version: v0.10.0
- vLLM main:
9edd1db02b

---------

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: SlightwindSec <slightwindsec@gmail.com>
This commit is contained in:
Wang Kunpeng
2025-08-12 14:12:12 +08:00
committed by GitHub
parent 81817908ca
commit dc585f148a
6 changed files with 169 additions and 37 deletions

View File

@@ -141,7 +141,8 @@ class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
def forward(
self,
input_,
is_prefill=True
is_prefill=True,
is_force_scatter=False
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
if self.input_is_parallel:
input_parallel = input_
@@ -160,7 +161,13 @@ class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
if not is_prefill and output_parallel.shape[0] % self.tp_size == 0:
num_tokens = output_parallel.shape[0]
if is_force_scatter and num_tokens % self.tp_size:
output_parallel = nn.functional.pad(
output_parallel, (0, 0, 0, -num_tokens % self.tp_size))
if is_force_scatter or (not is_prefill
and output_parallel.shape[0] % self.tp_size
== 0):
output = tensor_model_parallel_reduce_scatter(output_parallel,
dim=0)
else:
@@ -180,7 +187,8 @@ class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
def forward(
self,
input_,
is_prefill=True
is_prefill=True,
is_force_scatter=False
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
if self.input_is_parallel:
input_parallel = input_
@@ -347,13 +355,15 @@ class CustomDeepseekV2MoE(nn.Module):
reduce_results = not self.all_reduce_merge
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.shared_experts = CustomDeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=reduce_results,
force_replicate=self.enable_multistream_moe,
force_replicate=self.enable_multistream_moe
or enable_shared_expert_dp,
prefix=f"{prefix}.shared_experts",
)
else:
@@ -447,9 +457,11 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
assert num_heads % tp_size == 0
self.num_local_heads = num_heads // tp_size
self.tp_size = get_tensor_model_parallel_world_size()
assert num_heads % self.tp_size == 0
self.num_local_heads = num_heads // self.tp_size
self.layers = config.num_hidden_layers
self.first_k_dense_replace = config.first_k_dense_replace
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
@@ -462,6 +474,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_mla = \
ascend_config.torchair_graph_config.enable_multistream_mla
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(self.hidden_size,
@@ -501,8 +514,9 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
prefix=f"{prefix}.kv_b_proj")
if (config.n_routed_experts is not None
and self.debug_layer_idx >= config.first_k_dense_replace
and self.debug_layer_idx % config.moe_layer_freq == 0 and
ascend_config.torchair_graph_config.enable_multistream_moe):
and self.debug_layer_idx % config.moe_layer_freq == 0
and (ascend_config.torchair_graph_config.enable_multistream_moe
or self.enable_shared_expert_dp)):
self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce(
self.num_heads * self.v_head_dim,
self.hidden_size,
@@ -596,13 +610,27 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
output = output.view(-1, output_shape[-1])
return output
else:
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
hidden_states_or_q_c = get_tp_group().all_gather(
hidden_states_or_q_c, 0)
kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
kv_c, k_pe = kv_no_split.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
output_shape = hidden_states.shape
else:
num_tokens = hidden_states_or_q_c.shape[0]
rows = num_tokens // self.tp_size
if num_tokens % self.tp_size:
rows += 1
output_shape = (rows, hidden_states.shape[1])
return self.mla_attn(hidden_states_or_q_c,
kv_c_normed,
k_pe,
output_shape=hidden_states.shape)
output_shape=output_shape)
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
@@ -677,6 +705,8 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
eps=config.rms_norm_eps)
self.routed_scaling_factor = config.routed_scaling_factor
self.first_k_dense_replace = config.first_k_dense_replace
self.tp_group = get_tp_group().device_group
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
def forward(
self,
@@ -731,6 +761,18 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
# first layer.
residual *= 1. / self.routed_scaling_factor
tp_size = get_tensor_model_parallel_world_size()
if self.enable_shared_expert_dp and (
self.layer_idx == self.first_k_dense_replace
or self.layer_idx == self.layers) and tp_size > 1:
num_tokens, _ = residual.shape
if num_tokens % tp_size:
residual = nn.functional.pad(residual,
(0, 0, 0, -num_tokens % tp_size))
chunk_residual = torch.tensor_split(residual, tp_size, dim=0)
tp_rank = get_tensor_model_parallel_rank()
residual = chunk_residual[tp_rank]
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
@@ -756,6 +798,22 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
dim=0)
residual = tensor_model_parallel_all_gather(residual, dim=0)
# for last layer of main model and mtp layer.
if self.enable_shared_expert_dp and self.layer_idx >= (
self.layers - 1) and tp_size > 1:
hidden_states = get_tp_group().all_gather(hidden_states, 0)
residual = get_tp_group().all_gather(residual, 0)
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is not None:
num_tokens = attn_metadata.num_actual_tokens
else:
num_tokens = hidden_states.shape[0]
if num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:num_tokens]
residual = residual[:num_tokens]
return hidden_states, residual