[BugFix] Fix aclgraph accuracy problem on A2. (#3163)

This PR fixes an accuracy problem of aclgraph on A2. The problem was
introduced by PR #2980, which exposed the `all_reduce` of shared_experts
to torch dynamo. This PR moves all of that code into forward_impl
to shield it from torch dynamo.

- vLLM version: v0.10.2
- vLLM main:
17b4c6685c

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2025-09-28 21:31:55 +08:00
committed by GitHub
parent c3fee66806
commit 14d4ed5f0c
2 changed files with 30 additions and 19 deletions

View File

@@ -10,6 +10,7 @@ from vllm.forward_context import get_forward_context
from vllm.utils import direct_register_custom_op
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_forward_context import MoECommType
def _maybe_chunk_residual_impl(x: torch.Tensor,
@@ -147,6 +148,16 @@ def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
return
def _maybe_all_reduce_tensor_model_parallel_impl(
        final_hidden_states: torch.Tensor) -> torch.Tensor:
    """Conditionally all-reduce ``final_hidden_states`` across tensor-parallel ranks.

    The MoE communication type is read from the current forward context.
    For ``MoECommType.ALLTOALL`` and ``MoECommType.MC2`` the input tensor is
    returned untouched; for every other communication type the result of
    ``tensor_model_parallel_all_reduce`` is returned.

    Args:
        final_hidden_states: Tensor to (possibly) all-reduce.

    Returns:
        Either ``final_hidden_states`` itself or its all-reduced counterpart.
    """
    comm_type = get_forward_context().moe_comm_type

    # NOTE(review): presumably these communication paths already yield fully
    # reduced outputs, which is why the all-reduce is skipped -- confirm.
    skip_all_reduce = comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}
    if not skip_all_reduce:
        return tensor_model_parallel_all_reduce(final_hidden_states)
    return final_hidden_states
direct_register_custom_op(op_name="maybe_chunk_residual",
op_func=_maybe_chunk_residual_impl,
fake_impl=lambda x, residual: residual,
@@ -182,3 +193,9 @@ direct_register_custom_op(op_name="maybe_wait_prefetch_done",
fake_impl=_maybe_wait_prefetch_done_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
# Register `_maybe_all_reduce_tensor_model_parallel_impl` as the custom op
# named "maybe_all_reduce_tensor_model_parallel" under the "PrivateUse1"
# dispatch key. The op declares no mutated arguments.
# NOTE(review): the fake implementation returns its input as-is, which
# presumably serves as the shape/dtype stand-in during graph tracing --
# confirm that returning the input object itself is acceptable there.
direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel",
op_func=_maybe_all_reduce_tensor_model_parallel_impl,
fake_impl=lambda x: x,
mutates_args=[],
dispatch_key="PrivateUse1")