[main][prefill optimization] Optimize parallel strategies to reduce communication overhead (#2198)

### What this PR does / why we need it?
1.Shared Expert Sharding Strategy Update: Switched from TP-aligned to
pure DP for shared experts, enabling more efficient execution.
2.O_Proj AllReduce → ReduceScatter: Reduced communication overhead by
using ReduceScatter, made possible by pure DP sharding.
3.AllGather Postponed: Delayed to after QKV down projection to reduce
synchronization impact during prefill.

### How was this patch tested?
Adding ut case in `tests/ut/attention/test_mla_v1.py`

#### How to run

use parameter `--additional_config='{"enable_shared_expert_dp": true}'`

##### a.How to run eager mode

eg:
python -m vllm.entrypoints.openai.api_server --model=/model_path
--trust-remote-code -tp 8 -dp 2 --enable_expert_parallel --port 8002
--max-model-len 5120 --max-num-batched-tokens 16384 --enforce-eager
--disable-log-requests
--additional_config='{"ascend_scheduler_config":{"enabled":true},"enable_shared_expert_dp":
true,"chunked_prefill_for_mla":true}'

##### b.How to run graph mode

eg:
python -m vllm.entrypoints.openai.api_server --model=/model_path
--trust-remote-code -tp 8 -dp 2 --enable_expert_parallel --port 8002
--max-model-len 5120 --max-num-batched-tokens 16384
--disable-log-requests
--additional_config='{"ascend_scheduler_config":{"enabled":true},"enable_shared_expert_dp":
true,"chunked_prefill_for_mla":true,"torchair_graph_config":{"enabled":true}}'


- vLLM version: v0.10.0
- vLLM main:
9edd1db02b

---------

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: SlightwindSec <slightwindsec@gmail.com>
This commit is contained in:
Wang Kunpeng
2025-08-12 14:12:12 +08:00
committed by GitHub
parent 81817908ca
commit dc585f148a
6 changed files with 169 additions and 37 deletions

View File

@@ -32,6 +32,7 @@ The following table lists the additional configuration options available in vLLM
| `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
| `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. |
| `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
| `enable_shared_expert_dp` | bool | `True` | When the shared expert in DP, it has better performance but consumes more memory. When the memory is sensitive, this switch can be turned off manually. |
The details of each config option are as follows:

View File

@@ -691,3 +691,40 @@ class TestAscendMLAImpl(TestBase):
self.assertEqual(result.shape[2], self.impl.v_head_dim)
mock_up_proj.assert_called_once()
mock_page_attention_mla.assert_called_once()
@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._forward_prefill")
@patch("torch_npu._npu_reshape_and_cache")
def test_forward_without_graph(self, _, mock_forward_prefill):
self.impl.running_in_graph = False
self.impl.torchair_graph_enabled = False
num_tokens = 100
num_blocks = 256
block_size = 4
rotary_emb_return_value = (torch.randn(num_tokens, 16,
self.impl.kv_lora_rank),
torch.randn(0, 1, self.impl.kv_lora_rank))
self.impl.rotary_emb.side_effect = lambda *args, **kwargs: rotary_emb_return_value
self.impl.o_proj.side_effect = lambda *args, **kwargs: torch.randn(
1, num_blocks, 128)
hidden_states_or_q_c = torch.randn(num_tokens, self.impl.q_lora_rank)
hidden_states_or_kv_c_normed = torch.randn(num_tokens,
self.impl.kv_lora_rank)
k_pe = torch.randn(num_tokens, self.impl.qk_rope_head_dim)
kv_cache = (torch.randn(num_blocks, block_size, self.impl.num_heads,
self.impl.kv_lora_rank),
torch.randn(num_blocks, block_size, self.impl.num_heads,
self.impl.qk_rope_head_dim))
output = torch.randn(num_tokens, self.impl.num_heads,
self.impl.v_head_dim)
metadata = MagicMock()
metadata.num_decodes = 0
metadata.num_prefills = num_tokens
mock_forward_prefill.return_value = torch.randn(
0, self.impl.num_heads * self.impl.v_head_dim)
result = self.impl.forward(None, hidden_states_or_q_c,
hidden_states_or_kv_c_normed, k_pe,
kv_cache, metadata, output, False)
self.assertEqual(result.shape[0], num_tokens)

View File

@@ -47,6 +47,9 @@ class AscendConfig:
self.expert_map_path = additional_config.get("expert_map_path", None)
self.chunked_prefill_for_mla = additional_config.get(
"chunked_prefill_for_mla", False)
self.enable_shared_expert_dp = additional_config.get(
"enable_shared_expert_dp", True
) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
class TorchairGraphConfig:
@@ -166,6 +169,10 @@ def check_ascend_config(vllm_config, enforce_eager):
raise NotImplementedError(
"Torchair graph mode only works with following model types:"
f"{TORCHAIR_MODEL_LIST}.")
if ascend_config.enable_shared_expert_dp:
logger.warning(
"enable_shared_expert_dp is not supported for torchair graph mode currently, "
"it has been disabled automatically.")
# aclgraph case
else:
# aclgraph doesn't work with deepseek model and only qwen model is well tested.

View File

@@ -621,6 +621,7 @@ class AscendMLAImpl(MLAAttentionImpl):
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
# Adapt torch air graph mode with spec decoding.
speculative_config = get_current_vllm_config().speculative_config
@@ -635,6 +636,8 @@ class AscendMLAImpl(MLAAttentionImpl):
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
if hasattr(self, "running_in_graph") and not self.running_in_graph:
return x
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
npu_prefetch(self.o_proj.weight,
x,
@@ -905,14 +908,7 @@ class AscendMLAImpl(MLAAttentionImpl):
] and not ascend_config.chunked_prefill_for_mla:
attn_output = attn_output_torch
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self.o_proj(attn_output, is_prefill=True)[0]
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
return self.o_proj(attn_output, is_prefill=True)[0]
return attn_output
def exec_kv(
self,
@@ -1249,6 +1245,12 @@ class AscendMLAImpl(MLAAttentionImpl):
key_cache=kv_cache[0],
value_cache=kv_cache[1],
slot_indices=attn_metadata.slot_mapping)
if not self.running_in_graph:
o_proj_input_shape = (num_actual_toks,
self.num_heads * self.v_head_dim)
o_proj_input = torch.empty(o_proj_input_shape,
dtype=hidden_states_or_q_c.dtype,
device=hidden_states_or_q_c.device)
if has_prefill:
# FIX: aicore move should be also placed on the comm stream in dbo,
# otherwise it may affect the accuracy
@@ -1259,11 +1261,12 @@ class AscendMLAImpl(MLAAttentionImpl):
attn_metadata)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
output[num_decode_tokens:] = output_prefill
current_ms_metadata.after_comm_event.record()
current_ms_metadata.before_comm_event.wait()
o_proj_input[num_decode_tokens:] = output_prefill
else:
output[num_decode_tokens:] = output_prefill
o_proj_input[num_decode_tokens:] = output_prefill
if has_decode:
if self.running_in_graph:
@@ -1280,9 +1283,32 @@ class AscendMLAImpl(MLAAttentionImpl):
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
with torch.npu.stream(current_ms_metadata.comm_stream):
output[:num_decode_tokens] = output_decode
current_ms_metadata.after_comm_event.record()
o_proj_input[:num_decode_tokens] = output_decode
else:
output[:num_decode_tokens] = output_decode
o_proj_input[:num_decode_tokens] = output_decode
current_ms_metadata = get_multistream_comm_context()
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
if current_ms_metadata is None:
npu_prefetch(self.o_proj.weight,
o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=enable_multistream_mla)
output[...] = self.o_proj(
o_proj_input,
is_prefill=True,
is_force_scatter=self.enable_shared_expert_dp)[0]
else:
with torch.npu.stream(current_ms_metadata.comm_stream):
npu_prefetch(self.o_proj.weight,
o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=enable_multistream_mla)
output[...] = self.o_proj(
o_proj_input,
is_prefill=True,
is_force_scatter=self.enable_shared_expert_dp)[0]
current_ms_metadata.after_comm_event.record()
del o_proj_input
return output_padded

View File

@@ -141,7 +141,8 @@ class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
def forward(
self,
input_,
is_prefill=True
is_prefill=True,
is_force_scatter=False
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
if self.input_is_parallel:
input_parallel = input_
@@ -160,7 +161,13 @@ class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
if not is_prefill and output_parallel.shape[0] % self.tp_size == 0:
num_tokens = output_parallel.shape[0]
if is_force_scatter and num_tokens % self.tp_size:
output_parallel = nn.functional.pad(
output_parallel, (0, 0, 0, -num_tokens % self.tp_size))
if is_force_scatter or (not is_prefill
and output_parallel.shape[0] % self.tp_size
== 0):
output = tensor_model_parallel_reduce_scatter(output_parallel,
dim=0)
else:
@@ -180,7 +187,8 @@ class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
def forward(
self,
input_,
is_prefill=True
is_prefill=True,
is_force_scatter=False
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
if self.input_is_parallel:
input_parallel = input_
@@ -347,13 +355,15 @@ class CustomDeepseekV2MoE(nn.Module):
reduce_results = not self.all_reduce_merge
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.shared_experts = CustomDeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=reduce_results,
force_replicate=self.enable_multistream_moe,
force_replicate=self.enable_multistream_moe
or enable_shared_expert_dp,
prefix=f"{prefix}.shared_experts",
)
else:
@@ -447,9 +457,11 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
assert num_heads % tp_size == 0
self.num_local_heads = num_heads // tp_size
self.tp_size = get_tensor_model_parallel_world_size()
assert num_heads % self.tp_size == 0
self.num_local_heads = num_heads // self.tp_size
self.layers = config.num_hidden_layers
self.first_k_dense_replace = config.first_k_dense_replace
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
@@ -462,6 +474,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_mla = \
ascend_config.torchair_graph_config.enable_multistream_mla
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(self.hidden_size,
@@ -501,8 +514,9 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
prefix=f"{prefix}.kv_b_proj")
if (config.n_routed_experts is not None
and self.debug_layer_idx >= config.first_k_dense_replace
and self.debug_layer_idx % config.moe_layer_freq == 0 and
ascend_config.torchair_graph_config.enable_multistream_moe):
and self.debug_layer_idx % config.moe_layer_freq == 0
and (ascend_config.torchair_graph_config.enable_multistream_moe
or self.enable_shared_expert_dp)):
self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce(
self.num_heads * self.v_head_dim,
self.hidden_size,
@@ -596,13 +610,27 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
output = output.view(-1, output_shape[-1])
return output
else:
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
hidden_states_or_q_c = get_tp_group().all_gather(
hidden_states_or_q_c, 0)
kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
kv_c, k_pe = kv_no_split.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
output_shape = hidden_states.shape
else:
num_tokens = hidden_states_or_q_c.shape[0]
rows = num_tokens // self.tp_size
if num_tokens % self.tp_size:
rows += 1
output_shape = (rows, hidden_states.shape[1])
return self.mla_attn(hidden_states_or_q_c,
kv_c_normed,
k_pe,
output_shape=hidden_states.shape)
output_shape=output_shape)
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
@@ -677,6 +705,8 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
eps=config.rms_norm_eps)
self.routed_scaling_factor = config.routed_scaling_factor
self.first_k_dense_replace = config.first_k_dense_replace
self.tp_group = get_tp_group().device_group
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
def forward(
self,
@@ -731,6 +761,18 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
# first layer.
residual *= 1. / self.routed_scaling_factor
tp_size = get_tensor_model_parallel_world_size()
if self.enable_shared_expert_dp and (
self.layer_idx == self.first_k_dense_replace
or self.layer_idx == self.layers) and tp_size > 1:
num_tokens, _ = residual.shape
if num_tokens % tp_size:
residual = nn.functional.pad(residual,
(0, 0, 0, -num_tokens % tp_size))
chunk_residual = torch.tensor_split(residual, tp_size, dim=0)
tp_rank = get_tensor_model_parallel_rank()
residual = chunk_residual[tp_rank]
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
@@ -756,6 +798,22 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
dim=0)
residual = tensor_model_parallel_all_gather(residual, dim=0)
# for last layer of main model and mtp layer.
if self.enable_shared_expert_dp and self.layer_idx >= (
self.layers - 1) and tp_size > 1:
hidden_states = get_tp_group().all_gather(hidden_states, 0)
residual = get_tp_group().all_gather(residual, 0)
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is not None:
num_tokens = attn_metadata.num_actual_tokens
else:
num_tokens = hidden_states.shape[0]
if num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:num_tokens]
residual = residual[:num_tokens]
return hidden_states, residual

View File

@@ -1268,6 +1268,7 @@ class AscendFusedMoE(FusedMoE):
self.enable_multistream_moe = \
ascend_config.torchair_graph_config.enable_multistream_moe and \
self.torchair_graph_enabled
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
@@ -1408,22 +1409,24 @@ class AscendFusedMoE(FusedMoE):
else:
# TODO: Determine if we can remove the padding
padding_size = tp_size
if num_tokens < padding_size:
if num_tokens < padding_size and not self.enable_shared_expert_dp:
hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, padding_size - num_tokens))
router_logits = nn.functional.pad(
router_logits, (0, 0, 0, padding_size - num_tokens))
if tp_size > 1:
chunk_hidden_states = torch.tensor_split(hidden_states,
tp_size,
dim=0)
chunk_router_logits = torch.tensor_split(router_logits,
tp_size,
dim=0)
chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
tp_rank = get_tensor_model_parallel_rank()
hidden_states = chunk_hidden_states[tp_rank]
router_logits = chunk_router_logits[tp_rank]
if not self.enable_shared_expert_dp:
chunk_hidden_states = torch.tensor_split(hidden_states,
tp_size,
dim=0)
chunk_router_logits = torch.tensor_split(router_logits,
tp_size,
dim=0)
hidden_states = chunk_hidden_states[tp_rank]
router_logits = chunk_router_logits[tp_rank]
chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
mc2_mask = chunk_mc2_mask[tp_rank]
if self.dp_size > 1:
@@ -1490,7 +1493,7 @@ class AscendFusedMoE(FusedMoE):
if (fused_moe_state not in [
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
FusedMoEState.NaiveMulticast
] and not replace_allreduce):
] and not replace_allreduce and not self.enable_shared_expert_dp):
if tp_size > 1:
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
self.tp_group)
@@ -1500,7 +1503,7 @@ class AscendFusedMoE(FusedMoE):
final_hidden_states = e_hidden_states
if num_tokens < padding_size:
final_hidden_states = final_hidden_states[:num_tokens]
elif self.dp_size > 1:
elif self.dp_size > 1 and not self.enable_shared_expert_dp:
if fused_moe_state == FusedMoEState.NaiveMulticast:
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
self.dp_rank - 1]