diff --git a/vllm_ascend/ops/fused_moe/prepare_finalize.py b/vllm_ascend/ops/fused_moe/prepare_finalize.py index f158a4bf..f54d4579 100644 --- a/vllm_ascend/ops/fused_moe/prepare_finalize.py +++ b/vllm_ascend/ops/fused_moe/prepare_finalize.py @@ -29,7 +29,10 @@ from vllm.distributed.parallel_state import ( from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe import FusedMoEConfig -from vllm_ascend.utils import enable_sp +from vllm_ascend.utils import enable_sp, prefill_context_parallel_enable + +if prefill_context_parallel_enable(): + from vllm.distributed import get_pcp_group class QuantType(Enum): @@ -382,6 +385,17 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): hidden_states, 0) router_logits = self.moe_config.dp_group.all_gather( router_logits, 0) + + if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1: + hidden_states = get_pcp_group().all_gather( + hidden_states, + dim=0, + ) + router_logits = get_pcp_group().all_gather( + router_logits, + dim=0, + ) + return hidden_states, router_logits, None, None def finalize(self, @@ -431,6 +445,9 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): hidden_states = get_dp_group().reduce_scatter(hidden_states, 0) hidden_states = hidden_states[:self.num_tokens] + if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1: + hidden_states = get_pcp_group().reduce_scatter(hidden_states, + dim=0) if reduce_results and (self.moe_config.tp_size > 1 or self.moe_config.ep_size > 1): hidden_states = tensor_model_parallel_all_reduce(hidden_states) @@ -504,6 +521,16 @@ class PrepareAndFinalizeWithNaiveMulticast(PrepareAndFinalize): router_logits = self._naive_multicast(router_logits, self.cu_tokens_across_dp_cpu) + if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1: + hidden_states = get_pcp_group().all_gather( + hidden_states, + dim=0, + ) + router_logits = get_pcp_group().all_gather( + router_logits, + dim=0, + ) + return hidden_states, router_logits, None, None def finalize(self, @@ -528,6 +555,10 @@ class PrepareAndFinalizeWithNaiveMulticast(PrepareAndFinalize): hidden_states) # Sum across DP hidden_states = hidden_states[start:end, :] + if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1: + hidden_states = get_pcp_group().reduce_scatter(hidden_states, + dim=0) + if reduce_results and (self.moe_config.tp_size > 1 or self.moe_config.ep_size > 1): hidden_states = tensor_model_parallel_all_reduce(hidden_states)