[main][prefill optimization] Optimize parallel strategies to reduce communication overhead (#2198)

### What this PR does / why we need it?
1.Shared Expert Sharding Strategy Update: Switched from TP-aligned to
pure DP for shared experts, enabling more efficient execution.
2.O_Proj AllReduce → ReduceScatter: Reduced communication overhead by
using ReduceScatter, made possible by pure DP sharding.
3.AllGather Postponed: Delayed to after QKV down projection to reduce
synchronization impact during prefill.

### How was this patch tested?
Adding ut case in `tests/ut/attention/test_mla_v1.py`

#### How to run

use parameter `--additional_config='{"enable_shared_expert_dp": true}'`

##### a.How to run eager mode

eg:
python -m vllm.entrypoints.openai.api_server --model=/model_path
--trust-remote-code -tp 8 -dp 2 --enable_expert_parallel --port 8002
--max-model-len 5120 --max-num-batched-tokens 16384 --enforce-eager
--disable-log-requests
--additional_config='{"ascend_scheduler_config":{"enabled":true},"enable_shared_expert_dp":
true,"chunked_prefill_for_mla":true}'

##### b.How to run graph mode

eg:
python -m vllm.entrypoints.openai.api_server --model=/model_path
--trust-remote-code -tp 8 -dp 2 --enable_expert_parallel --port 8002
--max-model-len 5120 --max-num-batched-tokens 16384
--disable-log-requests
--additional_config='{"ascend_scheduler_config":{"enabled":true},"enable_shared_expert_dp":
true,"chunked_prefill_for_mla":true,"torchair_graph_config":{"enabled":true}}'


- vLLM version: v0.10.0
- vLLM main:
9edd1db02b

---------

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: SlightwindSec <slightwindsec@gmail.com>
This commit is contained in:
Wang Kunpeng
2025-08-12 14:12:12 +08:00
committed by GitHub
parent 81817908ca
commit dc585f148a
6 changed files with 169 additions and 37 deletions

View File

@@ -621,6 +621,7 @@ class AscendMLAImpl(MLAAttentionImpl):
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
# Adapt torch air graph mode with spec decoding.
speculative_config = get_current_vllm_config().speculative_config
@@ -635,6 +636,8 @@ class AscendMLAImpl(MLAAttentionImpl):
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
if hasattr(self, "running_in_graph") and not self.running_in_graph:
return x
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
npu_prefetch(self.o_proj.weight,
x,
@@ -905,14 +908,7 @@ class AscendMLAImpl(MLAAttentionImpl):
] and not ascend_config.chunked_prefill_for_mla:
attn_output = attn_output_torch
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self.o_proj(attn_output, is_prefill=True)[0]
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
return self.o_proj(attn_output, is_prefill=True)[0]
return attn_output
def exec_kv(
self,
@@ -1249,6 +1245,12 @@ class AscendMLAImpl(MLAAttentionImpl):
key_cache=kv_cache[0],
value_cache=kv_cache[1],
slot_indices=attn_metadata.slot_mapping)
if not self.running_in_graph:
o_proj_input_shape = (num_actual_toks,
self.num_heads * self.v_head_dim)
o_proj_input = torch.empty(o_proj_input_shape,
dtype=hidden_states_or_q_c.dtype,
device=hidden_states_or_q_c.device)
if has_prefill:
# FIX: aicore move should be also placed on the comm stream in dbo,
# otherwise it may affect the accuracy
@@ -1259,11 +1261,12 @@ class AscendMLAImpl(MLAAttentionImpl):
attn_metadata)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
output[num_decode_tokens:] = output_prefill
current_ms_metadata.after_comm_event.record()
current_ms_metadata.before_comm_event.wait()
o_proj_input[num_decode_tokens:] = output_prefill
else:
output[num_decode_tokens:] = output_prefill
o_proj_input[num_decode_tokens:] = output_prefill
if has_decode:
if self.running_in_graph:
@@ -1280,9 +1283,32 @@ class AscendMLAImpl(MLAAttentionImpl):
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
with torch.npu.stream(current_ms_metadata.comm_stream):
output[:num_decode_tokens] = output_decode
current_ms_metadata.after_comm_event.record()
o_proj_input[:num_decode_tokens] = output_decode
else:
output[:num_decode_tokens] = output_decode
o_proj_input[:num_decode_tokens] = output_decode
current_ms_metadata = get_multistream_comm_context()
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
if current_ms_metadata is None:
npu_prefetch(self.o_proj.weight,
o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=enable_multistream_mla)
output[...] = self.o_proj(
o_proj_input,
is_prefill=True,
is_force_scatter=self.enable_shared_expert_dp)[0]
else:
with torch.npu.stream(current_ms_metadata.comm_stream):
npu_prefetch(self.o_proj.weight,
o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=enable_multistream_mla)
output[...] = self.o_proj(
o_proj_input,
is_prefill=True,
is_force_scatter=self.enable_shared_expert_dp)[0]
current_ms_metadata.after_comm_event.record()
del o_proj_input
return output_padded