diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py
index 554b40e..ac22b69 100644
--- a/vllm_ascend/ops/common_fused_moe.py
+++ b/vllm_ascend/ops/common_fused_moe.py
@@ -203,12 +203,8 @@ class AscendFusedMoE(FusedMoE):
         `finalize` function. In `allgathercommimpl`, we still need to
         all-reduce the outputs since each rank only has partial outputs.
         """
-        forward_context = get_forward_context()
-        moe_comm_type = forward_context.moe_comm_type
-        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
-            return final_hidden_states
-        else:
-            return tensor_model_parallel_all_reduce(final_hidden_states)
+        return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
+            final_hidden_states)
 
     def forward_impl(self, hidden_states: torch.Tensor,
                      router_logits: torch.Tensor):
@@ -333,6 +329,15 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        shared_out, fused_out = AscendFusedMoE.forward(
+            self,
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+        )
+        return shared_out, fused_out
+
+    def forward_impl(self, hidden_states: torch.Tensor,
+                     router_logits: torch.Tensor):
         # Make sure the shared experts stream begins after hidden_states are ready.
         if self.multistream_overlap_shared_expert:
             self.shared_expert_stream.wait_stream(  # type: ignore
@@ -347,9 +352,8 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
         moe_comm_type = forward_context.moe_comm_type
         if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
             shared_out = tensor_model_parallel_all_reduce(shared_out)
-
-        _, fused_out = AscendFusedMoE.forward(
+        fused_output = AscendFusedMoE.forward_impl(
             self,
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
@@ -356,17 +360,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
         # Make sure the default stream waits for the shared experts stream to finish.
         if self.multistream_overlap_shared_expert:
             torch.npu.current_stream().wait_stream(self.shared_expert_stream)
-        return shared_out, fused_out
-
-    def forward_impl(self, hidden_states: torch.Tensor,
-                     router_logits: torch.Tensor):
-        shared_output = torch.empty(1)
-        fused_output = AscendFusedMoE.forward_impl(
-            self,
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-        )
-        return shared_output, fused_output
+        return shared_out, fused_output
 
 
 UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py
index a702b35..438bff1 100644
--- a/vllm_ascend/ops/register_custom_ops.py
+++ b/vllm_ascend/ops/register_custom_ops.py
@@ -10,6 +10,7 @@ from vllm.forward_context import get_forward_context
 from vllm.utils import direct_register_custom_op
 
 import vllm_ascend.envs as envs_ascend
+from vllm_ascend.ascend_forward_context import MoECommType
 
 
 def _maybe_chunk_residual_impl(x: torch.Tensor,
@@ -147,6 +148,16 @@ def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
     return
 
 
+def _maybe_all_reduce_tensor_model_parallel_impl(
+        final_hidden_states: torch.Tensor) -> torch.Tensor:
+    forward_context = get_forward_context()
+    moe_comm_type = forward_context.moe_comm_type
+    if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
+        return final_hidden_states
+    else:
+        return tensor_model_parallel_all_reduce(final_hidden_states)
+
+
 direct_register_custom_op(op_name="maybe_chunk_residual",
                           op_func=_maybe_chunk_residual_impl,
                           fake_impl=lambda x, residual: residual,
@@ -182,3 +193,9 @@ direct_register_custom_op(op_name="maybe_wait_prefetch_done",
                           fake_impl=_maybe_wait_prefetch_done_impl_fake,
                           mutates_args=[],
                           dispatch_key="PrivateUse1")
+
+direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel",
+                          op_func=_maybe_all_reduce_tensor_model_parallel_impl,
+                          fake_impl=lambda x: x,
+                          mutates_args=[],
+                          dispatch_key="PrivateUse1")