[BugFix] Fix aclgraph accu problem in A2. (#3163)

This PR fixes accuracy problem of aclgraph on A2. The problem is
introduced by PR #2980, which makes the `all_reduce` of shared_experts
exposed to torch dynamo. This PR moves all the codes into forward_impl
to shiled from torch dynamo.

- vLLM version: v0.10.2
- vLLM main:
17b4c6685c

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2025-09-28 21:31:55 +08:00
committed by GitHub
parent c3fee66806
commit 14d4ed5f0c
2 changed files with 30 additions and 19 deletions

View File

@@ -203,12 +203,8 @@ class AscendFusedMoE(FusedMoE):
`finalize` function. In `allgathercommimpl`, we still need to all-reduce the
outputs since each rank only has partial outputs.
"""
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
final_hidden_states)
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
@@ -333,6 +329,15 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
shared_out, fused_out = AscendFusedMoE.forward(
self,
hidden_states=hidden_states,
router_logits=router_logits,
)
return shared_out, fused_out
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
# Make sure the shared experts stream begins after hidden_states are ready.
if self.multistream_overlap_shared_expert:
self.shared_expert_stream.wait_stream( # type: ignore
@@ -347,8 +352,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
shared_out = tensor_model_parallel_all_reduce(shared_out)
_, fused_out = AscendFusedMoE.forward(
fused_output = AscendFusedMoE.forward_impl(
self,
hidden_states=hidden_states,
router_logits=router_logits,
@@ -356,17 +360,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
# Make sure the default stream waits for the shared experts stream to finish.
if self.multistream_overlap_shared_expert:
torch.npu.current_stream().wait_stream(self.shared_expert_stream)
return shared_out, fused_out
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
shared_output = torch.empty(1)
fused_output = AscendFusedMoE.forward_impl(
self,
hidden_states=hidden_states,
router_logits=router_logits,
)
return shared_output, fused_output
return shared_out, fused_output
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func

View File

@@ -10,6 +10,7 @@ from vllm.forward_context import get_forward_context
from vllm.utils import direct_register_custom_op
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_forward_context import MoECommType
def _maybe_chunk_residual_impl(x: torch.Tensor,
@@ -147,6 +148,16 @@ def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
return
def _maybe_all_reduce_tensor_model_parallel_impl(
final_hidden_states: torch.Tensor) -> torch.Tensor:
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
direct_register_custom_op(op_name="maybe_chunk_residual",
op_func=_maybe_chunk_residual_impl,
fake_impl=lambda x, residual: residual,
@@ -182,3 +193,9 @@ direct_register_custom_op(op_name="maybe_wait_prefetch_done",
fake_impl=_maybe_wait_prefetch_done_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel",
op_func=_maybe_all_reduce_tensor_model_parallel_impl,
fake_impl=lambda x: x,
mutates_args=[],
dispatch_key="PrivateUse1")