From ffe51eedd6e609cd1bcd2ee595eda49f0e271f0a Mon Sep 17 00:00:00 2001
From: weichen
Date: Tue, 23 Dec 2025 18:53:48 +0800
Subject: [PATCH] [Refactor][MoE] Reuse vLLM's all_reduce logic (#5189)

### What this PR does / why we need it?
Move the all_reduce logic into AscendFusedMoE.forward and reuse vLLM's
implementation, instead of performing the tensor-parallel all_reduce in the
Ascend prepare/finalize path.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
e2e & ut

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

Signed-off-by: weichen
Co-authored-by: weijinqian0 <1184188277@qq.com>
---
 tests/ut/ops/test_prepare_finalize.py         |  7 +----
 vllm_ascend/ops/fused_moe/fused_moe.py        | 27 -------------------
 vllm_ascend/ops/fused_moe/prepare_finalize.py |  5 ----
 3 files changed, 1 insertion(+), 38 deletions(-)

diff --git a/tests/ut/ops/test_prepare_finalize.py b/tests/ut/ops/test_prepare_finalize.py
index fe2932a9..bb867155 100644
--- a/tests/ut/ops/test_prepare_finalize.py
+++ b/tests/ut/ops/test_prepare_finalize.py
@@ -169,15 +169,12 @@ class TestPrepareAndFinalize(unittest.TestCase):
         self.assertEqual(final_result.shape[0], 2)
 
     @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_dp_group")
-    @patch(
-        "vllm_ascend.ops.fused_moe.prepare_finalize.tensor_model_parallel_all_reduce"
-    )
     @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
     @patch("vllm_ascend.ops.fused_moe.prepare_finalize.enable_sp",
            return_value=False)
     def test_allgather_prepare_finalize(self, mock_enable_sp,
                                         mock_get_forward_context,
-                                        mock_tp_all_reduce, mock_get_dp_group):
+                                        mock_get_dp_group):
         # Mock forward context
         mock_context = MagicMock()
         mock_context.max_tokens_across_dp = 6
@@ -222,7 +219,5 @@ class TestPrepareAndFinalize(unittest.TestCase):
 
         self.assertEqual(result.shape[0], 3)
 
-        # Test with TP all-reduce
-        mock_tp_all_reduce.return_value = result
         result_with_tp = layer.finalize(h_out, reduce_results=True)
         self.assertEqual(result_with_tp.shape[0], 3)
diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index b4cbbb48..a9547a5a 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -18,7 +18,6 @@ import os.path
 from typing import Any, Callable, Optional
 
 import torch
-import torch.nn.functional as F
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
                               tensor_model_parallel_all_reduce)
@@ -292,32 +291,6 @@ class AscendFusedMoE(FusedMoE):
         return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
             final_hidden_states)
 
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        router_logits: torch.Tensor,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        og_hidden_states = hidden_states.shape[-1]
-        if self.hidden_size != og_hidden_states:
-            hidden_states = F.pad(
-                hidden_states,
-                (0, self.hidden_size - og_hidden_states),
-                mode="constant",
-                value=0.0,
-            )
-        if self.shared_experts is None:
-            fused_output = torch.ops.vllm.moe_forward(hidden_states,
-                                                      router_logits,
-                                                      self.layer_name)
-            return fused_output[..., :og_hidden_states]
-        else:
-            shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
-                hidden_states, router_logits, self.layer_name)
-            return (
-                shared_output[..., :og_hidden_states],
-                fused_output[..., :og_hidden_states],
-            )
-
     def forward_impl(self, hidden_states: torch.Tensor,
                      router_logits: torch.Tensor):
         assert self.quant_method is not None
diff --git a/vllm_ascend/ops/fused_moe/prepare_finalize.py b/vllm_ascend/ops/fused_moe/prepare_finalize.py
index 2e7db621..05f43912 100644
--- a/vllm_ascend/ops/fused_moe/prepare_finalize.py
+++ b/vllm_ascend/ops/fused_moe/prepare_finalize.py
@@ -22,7 +22,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch_npu
-from vllm.distributed import tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import (
     get_dp_group, get_pcp_group, get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size)
@@ -470,8 +469,4 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
         if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
             hidden_states = get_pcp_group().reduce_scatter(hidden_states,
                                                            dim=0)
-        if reduce_results and (self.moe_config.tp_size > 1
-                               or self.moe_config.ep_size > 1):
-            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-
         return hidden_states
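
Taken together, the hunks above change where the reduction happens: `PrepareAndFinalizeWithAllGather.finalize` no longer calls `tensor_model_parallel_all_reduce`, and `AscendFusedMoE` drops its `forward` override so vLLM's `FusedMoE.forward` drives the layer, with the reduction going through `torch.ops.vllm.maybe_all_reduce_tensor_model_parallel` (already called in the Ascend combine path, as visible in the context of the second fused_moe.py hunk). The sketch below only illustrates that split; the standalone `finalize` / `moe_forward_impl` functions and their signatures are simplified stand-ins, not the actual vllm_ascend code.

```python
# Minimal sketch of the post-change reduction path. The helper functions are
# illustrative stand-ins for the real vllm_ascend methods; only
# torch.ops.vllm.maybe_all_reduce_tensor_model_parallel comes from the diff
# above and requires vLLM to be installed so the custom op is registered.
import torch


def finalize(hidden_states: torch.Tensor,
             reduce_results: bool = False) -> torch.Tensor:
    # After this PR, PrepareAndFinalizeWithAllGather.finalize returns the
    # hidden states without a tensor-parallel all_reduce, even when
    # reduce_results is True (any PCP reduce-scatter still happens earlier).
    return hidden_states


def moe_forward_impl(hidden_states: torch.Tensor,
                     router_logits: torch.Tensor) -> torch.Tensor:
    # ... routing and expert computation elided ...
    final_hidden_states = finalize(hidden_states, reduce_results=True)
    # The cross-rank reduction now goes through vLLM's shared custom op
    # instead of an Ascend-specific all_reduce in prepare_finalize.
    return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
        final_hidden_states)
```

This mirrors what the updated unit test exercises: `layer.finalize(h_out, reduce_results=True)` now only checks the output shape, since no mocked all_reduce is involved.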