From 9f5ab59e307a66fd0b17916218c9387c328b1c59 Mon Sep 17 00:00:00 2001 From: Angazenn <92204292+Angazenn@users.noreply.github.com> Date: Tue, 27 May 2025 15:16:17 +0800 Subject: [PATCH] [WIP][BugFix]Fix accuracy issues caused by wrong etp_size passed into FusedMoEParallelConfig when using vLLM 0.9.0 (#961) ### What this PR does / why we need it? This PR fixes accuracy issues caused by code that adapts to `FusedMoEParallelConfig` in vLLM 0.9.0. The `tp_size` used to split weights is wrongly passed. The root cause is that the vLLM community and vLLM-Ascend use different methods to decide whether to use Expert Parallel. vLLM: vLLM uses a flag `enable_expert_parallel` to indicate whether to use EP and uses the following code to decide `ep_size`: ``` use_ep = (dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel) dp_size = dp_size_ dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 tp_size, tp_rank = flatten_tp_across_dp(dp_rank) if not use_ep: return FusedMoEParallelConfig(tp_size=tp_size, tp_rank=tp_rank, dp_size=dp_size, dp_rank=dp_rank, ep_size=1, ep_rank=0, use_ep=False) # DP + EP / TP + EP / DP + TP + EP assert use_ep # In EP, each device owns a set of experts fully. There is no tensor # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that. ep_size = tp_size ep_rank = tp_rank return FusedMoEParallelConfig(tp_size=1, tp_rank=0, dp_size=dp_size, dp_rank=dp_rank, ep_size=ep_size, ep_rank=ep_rank, use_ep=True) ``` vLLM-Ascend: vLLM-Ascend uses `etp` to specify Tensor Parallel in MoE. ``` self.ep_size = get_ep_group().world_size self.tp_size = get_etp_group().world_size self.dp_size = (dp_size if dp_size is not None else get_dp_group().world_size) ``` So there will be conflicts if we simply combine these two pieces of code. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? 
Signed-off-by: angazenn Co-authored-by: angazenn --- vllm_ascend/ops/fused_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 6313a75..bc3b86b 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -748,6 +748,7 @@ class AscendFusedMoE(FusedMoE): vllm_parallel_config=vllm_config.parallel_config)) self.moe_parallel_config.ep_size = get_ep_group().world_size + self.moe_parallel_config.tp_size = get_etp_group().world_size self.top_k = top_k self.num_experts = num_experts