From f8610b7d6702d21f9e932f8d1b4f16061da1806c Mon Sep 17 00:00:00 2001
From: LookAround0301
Date: Fri, 7 Nov 2025 09:29:33 +0800
Subject: [PATCH] [long_seq] fix A2 accuracy problem (#4030)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What this PR does / why we need it?
1. Update prepare_finalize.py: fix the A2 accuracy problem when pcp and dcp are enabled.

- vLLM version: v0.11.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/83f478bb19489b41e9d208b47b4bb5a95ac171ac

Signed-off-by: LookAround
---
 vllm_ascend/ops/fused_moe/prepare_finalize.py | 33 ++++++++++++++++++-
 1 file changed, 32 insertions(+), 1 deletion(-)

diff --git a/vllm_ascend/ops/fused_moe/prepare_finalize.py b/vllm_ascend/ops/fused_moe/prepare_finalize.py
index f158a4bf..f54d4579 100644
--- a/vllm_ascend/ops/fused_moe/prepare_finalize.py
+++ b/vllm_ascend/ops/fused_moe/prepare_finalize.py
@@ -29,7 +29,10 @@ from vllm.distributed.parallel_state import (
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 
-from vllm_ascend.utils import enable_sp
+from vllm_ascend.utils import enable_sp, prefill_context_parallel_enable
+
+if prefill_context_parallel_enable():
+    from vllm.distributed import get_pcp_group
 
 
 class QuantType(Enum):
@@ -382,6 +385,17 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
                 hidden_states, 0)
             router_logits = self.moe_config.dp_group.all_gather(
                 router_logits, 0)
+
+        if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
+            hidden_states = get_pcp_group().all_gather(
+                hidden_states,
+                dim=0,
+            )
+            router_logits = get_pcp_group().all_gather(
+                router_logits,
+                dim=0,
+            )
+
         return hidden_states, router_logits, None, None
 
     def finalize(self,
@@ -431,6 +445,9 @@
             hidden_states = get_dp_group().reduce_scatter(hidden_states, 0)
             hidden_states = hidden_states[:self.num_tokens]
 
+        if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
+            hidden_states = get_pcp_group().reduce_scatter(hidden_states,
+                                                           dim=0)
         if reduce_results and (self.moe_config.tp_size > 1
                                or self.moe_config.ep_size > 1):
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)
@@ -504,6 +521,16 @@ class PrepareAndFinalizeWithNaiveMulticast(PrepareAndFinalize):
             router_logits = self._naive_multicast(router_logits,
                                                   self.cu_tokens_across_dp_cpu)
 
+        if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
+            hidden_states = get_pcp_group().all_gather(
+                hidden_states,
+                dim=0,
+            )
+            router_logits = get_pcp_group().all_gather(
+                router_logits,
+                dim=0,
+            )
+
         return hidden_states, router_logits, None, None
 
     def finalize(self,
@@ -528,6 +555,10 @@ class PrepareAndFinalizeWithNaiveMulticast(PrepareAndFinalize):
                                              hidden_states)  # Sum across DP
             hidden_states = hidden_states[start:end, :]
 
+        if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
+            hidden_states = get_pcp_group().reduce_scatter(hidden_states,
+                                                           dim=0)
+
         if reduce_results and (self.moe_config.tp_size > 1
                                or self.moe_config.ep_size > 1):
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)