[BugFix] Fix aclgraph accu problem in A2. (#3163)
This PR fixes accuracy problem of aclgraph on A2. The problem is
introduced by PR #2980, which makes the `all_reduce` of shared_experts
exposed to torch dynamo. This PR moves all the codes into forward_impl
to shiled from torch dynamo.
- vLLM version: v0.10.2
- vLLM main:
17b4c6685c
---------
Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -203,12 +203,8 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
`finalize` function. In `allgathercommimpl`, we still need to all-reduce the
|
`finalize` function. In `allgathercommimpl`, we still need to all-reduce the
|
||||||
outputs since each rank only has partial outputs.
|
outputs since each rank only has partial outputs.
|
||||||
"""
|
"""
|
||||||
forward_context = get_forward_context()
|
return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
|
||||||
moe_comm_type = forward_context.moe_comm_type
|
final_hidden_states)
|
||||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
|
|
||||||
return final_hidden_states
|
|
||||||
else:
|
|
||||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
|
||||||
|
|
||||||
def forward_impl(self, hidden_states: torch.Tensor,
|
def forward_impl(self, hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor):
|
router_logits: torch.Tensor):
|
||||||
@@ -333,6 +329,15 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
shared_out, fused_out = AscendFusedMoE.forward(
|
||||||
|
self,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
router_logits=router_logits,
|
||||||
|
)
|
||||||
|
return shared_out, fused_out
|
||||||
|
|
||||||
|
def forward_impl(self, hidden_states: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor):
|
||||||
# Make sure the shared experts stream begins after hidden_states are ready.
|
# Make sure the shared experts stream begins after hidden_states are ready.
|
||||||
if self.multistream_overlap_shared_expert:
|
if self.multistream_overlap_shared_expert:
|
||||||
self.shared_expert_stream.wait_stream( # type: ignore
|
self.shared_expert_stream.wait_stream( # type: ignore
|
||||||
@@ -347,8 +352,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
|||||||
moe_comm_type = forward_context.moe_comm_type
|
moe_comm_type = forward_context.moe_comm_type
|
||||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
|
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
|
||||||
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
||||||
|
fused_output = AscendFusedMoE.forward_impl(
|
||||||
_, fused_out = AscendFusedMoE.forward(
|
|
||||||
self,
|
self,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
@@ -356,17 +360,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
|||||||
# Make sure the default stream waits for the shared experts stream to finish.
|
# Make sure the default stream waits for the shared experts stream to finish.
|
||||||
if self.multistream_overlap_shared_expert:
|
if self.multistream_overlap_shared_expert:
|
||||||
torch.npu.current_stream().wait_stream(self.shared_expert_stream)
|
torch.npu.current_stream().wait_stream(self.shared_expert_stream)
|
||||||
return shared_out, fused_out
|
return shared_out, fused_output
|
||||||
|
|
||||||
def forward_impl(self, hidden_states: torch.Tensor,
|
|
||||||
router_logits: torch.Tensor):
|
|
||||||
shared_output = torch.empty(1)
|
|
||||||
fused_output = AscendFusedMoE.forward_impl(
|
|
||||||
self,
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
router_logits=router_logits,
|
|
||||||
)
|
|
||||||
return shared_output, fused_output
|
|
||||||
|
|
||||||
|
|
||||||
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
|
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from vllm.forward_context import get_forward_context
|
|||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
import vllm_ascend.envs as envs_ascend
|
||||||
|
from vllm_ascend.ascend_forward_context import MoECommType
|
||||||
|
|
||||||
|
|
||||||
def _maybe_chunk_residual_impl(x: torch.Tensor,
|
def _maybe_chunk_residual_impl(x: torch.Tensor,
|
||||||
@@ -147,6 +148,16 @@ def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_all_reduce_tensor_model_parallel_impl(
|
||||||
|
final_hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
forward_context = get_forward_context()
|
||||||
|
moe_comm_type = forward_context.moe_comm_type
|
||||||
|
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
|
||||||
|
return final_hidden_states
|
||||||
|
else:
|
||||||
|
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(op_name="maybe_chunk_residual",
|
direct_register_custom_op(op_name="maybe_chunk_residual",
|
||||||
op_func=_maybe_chunk_residual_impl,
|
op_func=_maybe_chunk_residual_impl,
|
||||||
fake_impl=lambda x, residual: residual,
|
fake_impl=lambda x, residual: residual,
|
||||||
@@ -182,3 +193,9 @@ direct_register_custom_op(op_name="maybe_wait_prefetch_done",
|
|||||||
fake_impl=_maybe_wait_prefetch_done_impl_fake,
|
fake_impl=_maybe_wait_prefetch_done_impl_fake,
|
||||||
mutates_args=[],
|
mutates_args=[],
|
||||||
dispatch_key="PrivateUse1")
|
dispatch_key="PrivateUse1")
|
||||||
|
|
||||||
|
direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel",
|
||||||
|
op_func=_maybe_all_reduce_tensor_model_parallel_impl,
|
||||||
|
fake_impl=lambda x: x,
|
||||||
|
mutates_args=[],
|
||||||
|
dispatch_key="PrivateUse1")
|
||||||
|
|||||||
Reference in New Issue
Block a user