[long_seq] fix A2 accuracy problem (#4030)
### What this PR does / why we need it?
1. Update prepare_finalize.py: fix the A2 accuracy problem when prefill context parallel (pcp) and dcp are enabled.
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
Signed-off-by: LookAround <lixushi@huawei.com>
This commit is contained in:
@@ -29,7 +29,10 @@ from vllm.distributed.parallel_state import (
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 
-from vllm_ascend.utils import enable_sp
+from vllm_ascend.utils import enable_sp, prefill_context_parallel_enable
+
+if prefill_context_parallel_enable():
+    from vllm.distributed import get_pcp_group
 
 
 class QuantType(Enum):
@@ -382,6 +385,17 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
                 hidden_states, 0)
             router_logits = self.moe_config.dp_group.all_gather(
                 router_logits, 0)
+
+        if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
+            hidden_states = get_pcp_group().all_gather(
+                hidden_states,
+                dim=0,
+            )
+            router_logits = get_pcp_group().all_gather(
+                router_logits,
+                dim=0,
+            )
+
         return hidden_states, router_logits, None, None
 
     def finalize(self,
@@ -431,6 +445,9 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
             hidden_states = get_dp_group().reduce_scatter(hidden_states, 0)
             hidden_states = hidden_states[:self.num_tokens]
 
+        if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
+            hidden_states = get_pcp_group().reduce_scatter(hidden_states,
+                                                           dim=0)
         if reduce_results and (self.moe_config.tp_size > 1
                                or self.moe_config.ep_size > 1):
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)
@@ -504,6 +521,16 @@ class PrepareAndFinalizeWithNaiveMulticast(PrepareAndFinalize):
             router_logits = self._naive_multicast(router_logits,
                                                   self.cu_tokens_across_dp_cpu)
+
+        if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
+            hidden_states = get_pcp_group().all_gather(
+                hidden_states,
+                dim=0,
+            )
+            router_logits = get_pcp_group().all_gather(
+                router_logits,
+                dim=0,
+            )
 
         return hidden_states, router_logits, None, None
 
     def finalize(self,
@@ -528,6 +555,10 @@ class PrepareAndFinalizeWithNaiveMulticast(PrepareAndFinalize):
                 hidden_states)  # Sum across DP
             hidden_states = hidden_states[start:end, :]
 
+        if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
+            hidden_states = get_pcp_group().reduce_scatter(hidden_states,
+                                                           dim=0)
+
         if reduce_results and (self.moe_config.tp_size > 1
                                or self.moe_config.ep_size > 1):
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)
Reference in New Issue
Block a user