[long_seq] fix A2 accuracy problem (#4030)

### What this PR does / why we need it?
1、update prepare_finalize.py:fix A2 accuracy problem when pcp and dcp

- vLLM version: v0.11.0
- vLLM main:
83f478bb19

Signed-off-by: LookAround <lixushi@huawei.com>
This commit is contained in:
LookAround0301
2025-11-07 09:29:33 +08:00
committed by GitHub
parent e0d58d543b
commit f8610b7d67

View File

@@ -29,7 +29,10 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.utils import enable_sp
from vllm_ascend.utils import enable_sp, prefill_context_parallel_enable
if prefill_context_parallel_enable():
from vllm.distributed import get_pcp_group
class QuantType(Enum):
@@ -382,6 +385,17 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
hidden_states, 0)
router_logits = self.moe_config.dp_group.all_gather(
router_logits, 0)
if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
hidden_states = get_pcp_group().all_gather(
hidden_states,
dim=0,
)
router_logits = get_pcp_group().all_gather(
router_logits,
dim=0,
)
return hidden_states, router_logits, None, None
def finalize(self,
@@ -431,6 +445,9 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
hidden_states = get_dp_group().reduce_scatter(hidden_states, 0)
hidden_states = hidden_states[:self.num_tokens]
if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
hidden_states = get_pcp_group().reduce_scatter(hidden_states,
dim=0)
if reduce_results and (self.moe_config.tp_size > 1
or self.moe_config.ep_size > 1):
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
@@ -504,6 +521,16 @@ class PrepareAndFinalizeWithNaiveMulticast(PrepareAndFinalize):
router_logits = self._naive_multicast(router_logits,
self.cu_tokens_across_dp_cpu)
if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
hidden_states = get_pcp_group().all_gather(
hidden_states,
dim=0,
)
router_logits = get_pcp_group().all_gather(
router_logits,
dim=0,
)
return hidden_states, router_logits, None, None
def finalize(self,
@@ -528,6 +555,10 @@ class PrepareAndFinalizeWithNaiveMulticast(PrepareAndFinalize):
hidden_states) # Sum across DP
hidden_states = hidden_states[start:end, :]
if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
hidden_states = get_pcp_group().reduce_scatter(hidden_states,
dim=0)
if reduce_results and (self.moe_config.tp_size > 1
or self.moe_config.ep_size > 1):
hidden_states = tensor_model_parallel_all_reduce(hidden_states)