shared_experts+router_experts merge all_reduce(Improve TTOP 5ms) (#1395)

### What this PR does / why we need it?
When all_reduce_merge is enabled, shared_experts no longer performs its own
all_reduce inside the MLP; instead, the all_reduce is deferred until both
shared_experts and router_experts have finished, so a single merged
all_reduce is issued. In both prefill and decode, merging the all_reduce of
shared_experts and router_experts yields a performance benefit.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
bash examples/run_dp_attention_etp16.sh
bash examples/run_dp_attention_etp16_benmark.sh
- vLLM version: v0.9.1
- vLLM main:
977180c912

---------

Signed-off-by: ttanzhiqiang <389825161@qq.com>
This commit is contained in:
ttanzhiqiang
2025-07-10 12:07:05 +08:00
committed by GitHub
parent 997f156a51
commit 60519c71bd
5 changed files with 32 additions and 7 deletions

View File

@@ -3,9 +3,10 @@ export TASK_QUEUE_ENABLE=1
source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh
export ASCEND_LAUNCH_BLOCKING=0
export VLLM_VERSION=0.9.0
export VLLM_VERSION=0.9.1
nohup python -m vllm.entrypoints.openai.api_server --model=/mnt/deepseek/DeepSeek-R1-W8A8-VLLM \
--served-model-name auto \
--quantization ascend \
--trust-remote-code \
--distributed-executor-backend=mp \

View File

@@ -21,7 +21,8 @@ for concurrency in "${concurrency_array[@]}"; do
python /mnt/deepseek/vllm/benchmarks/benchmark_serving.py \
--backend vllm \
--trust-remote-code \
--model /mnt/deepseek/DeepSeek-R1-W8A8-VLLM \
--model auto \
--tokenizer /mnt/deepseek/DeepSeek-R1-W8A8-VLLM \
--dataset-name random \
--random-input-len 4096 \
--random-output-len 1536 \

View File

@@ -303,7 +303,6 @@ class CustomDeepseekV2MoE(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.n_shared_experts = config.n_shared_experts
self.routed_scaling_factor = config.routed_scaling_factor
if self.tp_size > config.n_routed_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
@@ -345,6 +344,8 @@ class CustomDeepseekV2MoE(nn.Module):
e_score_correction_bias=self.gate.e_score_correction_bias)
if config.n_shared_experts is not None:
self.all_reduce_merge = self.experts.all_reduce_merge
reduce_results = not self.all_reduce_merge
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
self.shared_experts = CustomDeepseekV2MLP(
@@ -352,7 +353,7 @@ class CustomDeepseekV2MoE(nn.Module):
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=True,
reduce_results=reduce_results,
force_replicate=self.enable_multistream_moe,
prefix=f"{prefix}.shared_experts",
)
@@ -403,6 +404,9 @@ class CustomDeepseekV2MoE(nn.Module):
hidden_states = (
experts_hidden_states[0] * self.routed_scaling_factor +
experts_hidden_states[1])
if self.all_reduce_merge:
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
return hidden_states

View File

@@ -44,8 +44,8 @@ from vllm_ascend.distributed.communication_op import \
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
get_fused_moe_state, is_310p, npu_stream_switch,
npu_wait_tensor)
get_all_reduce_merge_state, get_fused_moe_state,
is_310p, npu_stream_switch, npu_wait_tensor)
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
@@ -1146,6 +1146,10 @@ class AscendFusedMoE(FusedMoE):
self.log2phy = None
self.global_redundant_expert_num = 0
is_deepseek_v3_r1 = self.global_num_experts == 256
self.all_reduce_merge = get_all_reduce_merge_state(
self.moe_parallel_config.ep_size, is_deepseek_v3_r1)
ascend_config = get_ascend_config()
expert_map_path = ascend_config.expert_map_path
if expert_map_path and os.path.exists(expert_map_path):
@@ -1250,6 +1254,7 @@ class AscendFusedMoE(FusedMoE):
is_prefill, is_deepseek_v3_r1)
if shared_experts:
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
shared_hidden_states = shared_experts(hidden_states)
tp_size = get_tensor_model_parallel_world_size()
@@ -1351,7 +1356,7 @@ class AscendFusedMoE(FusedMoE):
else:
final_hidden_states = e_hidden_states
if tp_size > 1 and fused_moe_state in [
if tp_size > 1 and not self.all_reduce_merge and fused_moe_state in [
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
FusedMoEState.NaiveMulticast
]:

View File

@@ -425,6 +425,20 @@ class FusedMoEState(Enum):
NaiveMulticast = 4
# TODO(ttanzhiqiang): all_reduce merge
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
# Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model.
def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool) -> bool:
    """Return True when the shared/routed expert all_reduce should be merged.

    When merging is enabled, shared_experts skips its own all_reduce and a
    single all_reduce is performed after shared_experts + router_experts
    complete.

    Args:
        ep_size: expert-parallel world size.
        is_deepseek_v3_r1: whether the model is DeepSeek V3/R1 (256 experts).

    Returns:
        True if all_reduce merging applies to this configuration.
    """
    # Read the env flag unconditionally (matches original evaluation order).
    allgather_ep_enabled = envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP
    # The fusion operator torch_npu.npu_grouped_matmul_finalize_routing called
    # by allgather ep only supports deepseek v3/r1, so merging is restricted
    # to that model family.
    return is_deepseek_v3_r1 and (ep_size == 1
                                  or (allgather_ep_enabled and ep_size > 1))
# TODO(zzzzwwjj): add soc_version to choose branch
def get_fused_moe_state(ep_size: int, with_prefill: bool,
is_deepseek_v3_r1: bool):