From 37a0715edade3efae975c02ec8edf6f7a5d07530 Mon Sep 17 00:00:00 2001 From: weichen <132029610+Pr0Wh1teGivee@users.noreply.github.com> Date: Mon, 22 Sep 2025 19:12:58 +0800 Subject: [PATCH] [Refactor] Adjustments to moe_comm_method selection process (#3001) ### What this PR does / why we need it? Fix issues mentioned in https://github.com/vllm-project/vllm-ascend/pull/2791 and some minor refactoring. 1. Use Enum instead of string. 2. Avoid setting a new property to forward_context in AscendFusedMoE.forward(). 3. Enabling TokenDispatcherWithMoge. 4. Remove redundant code. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Qwen3-30B-A3B/Qwen3-30B-A3B-W8A8/DeepSeek-V3-W4A8-Pruing/deepseek-mtp/pangu-pro-moe-pruing: 1. Enable/Disable EP 2. Aclgraph & eager - vLLM version: v0.10.2 - vLLM main: https://github.com/vllm-project/vllm/commit/9607d5eb449711b349d4c2bee0a9c94afcc7ed14 Signed-off-by: Pr0Wh1teGivee Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com> --- tests/ut/ops/test_common_fused_moe.py | 51 +-------- tests/ut/ops/test_fused_ops.py | 31 ++++-- tests/ut/ops/test_moe_comm_method.py | 28 ++++- tests/ut/quantization/test_w4a8_dynamic.py | 11 +- tests/ut/worker/test_model_runner_v1.py | 19 ++-- vllm_ascend/ascend_forward_context.py | 15 ++- vllm_ascend/ops/common_fused_moe.py | 121 ++------------------ vllm_ascend/ops/fused_moe.py | 31 +----- vllm_ascend/ops/moe/moe_comm_method.py | 123 ++++++--------------- vllm_ascend/ops/moe/moe_mlp.py | 3 +- vllm_ascend/ops/moe/token_dispatcher.py | 20 +--- vllm_ascend/spec_decode/eagle_proposer.py | 12 +- vllm_ascend/spec_decode/mtp_proposer.py | 8 +- vllm_ascend/worker/model_runner_v1.py | 48 ++++---- 14 files changed, 170 insertions(+), 351 deletions(-) diff --git a/tests/ut/ops/test_common_fused_moe.py b/tests/ut/ops/test_common_fused_moe.py index 2c678e0..6153a4e 100644 --- a/tests/ut/ops/test_common_fused_moe.py +++ b/tests/ut/ops/test_common_fused_moe.py @@ -17,56 +17,7 @@ from unittest.mock import patch import torch from tests.ut.base import TestBase -from vllm_ascend.ops.common_fused_moe import AscendFusedMoE, fused_experts_moge - - -class TestFusedExpertsMoGE(TestBase): - - def test_fused_experts_moge(self): - with patch('torch_npu.npu_grouped_matmul') as mock_grouped_matmul, \ - patch('torch_npu.npu_swiglu') as mock_swiglu, \ - patch('vllm_ascend.utils.is_310p') as mock_is_310p: - - mock_is_310p.return_value = False - - mock_grouped_matmul.side_effect = lambda x, weight, **kwargs: [ - torch.randn(x[0].shape[0], weight[0].shape[1]) - ] - - mock_swiglu.side_effect = lambda x: x - - hidden_states = torch.randn(4, 128) - w1 = torch.randn(4, 256, 128) - w2 = torch.randn(4, 128, 128) - topk_weights = torch.rand(4, 1) - topk_ids = torch.tensor([[0], [1], [2], [3]], dtype=torch.long) - top_k = 1 - global_num_experts = 4 - - moe_parallel_config = type( - 'MockConfig', (), { - 'ep_size': 1, - 'tp_size': 1, - 'dp_size': 1, - 'tp_rank': 0, - 'dp_rank': 0, - 'ep_rank': 0, - 'use_ep': True - })() - - output = fused_experts_moge( - hidden_states=hidden_states, - w1=w1, - w2=w2, - moe_parallel_config=moe_parallel_config, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - global_num_experts=global_num_experts, - apply_router_weight_on_input=True, - ) - - self.assertEqual(output.shape, (4, 128)) +from vllm_ascend.ops.common_fused_moe import AscendFusedMoE class TestLoadWeight(TestBase): diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py index 001022a..a91fe5b 100644 --- a/tests/ut/ops/test_fused_ops.py +++ b/tests/ut/ops/test_fused_ops.py @@ -23,6 +23,7 @@ from pytest_mock import MockerFixture from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase from tests.ut.base import TestBase +from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ops.fused_moe import (AscendFusedMoE, AscendUnquantizedFusedMoEMethod) from vllm_ascend.ops.moe.experts_selector import select_experts @@ -55,6 +56,26 @@ def mock_npu_format_cast(weight_data, format): return weight_data +@pytest.fixture(autouse=True) +def setup_vllm_config_mock(mocker: MockerFixture): + mock_hf_config = MagicMock() + mock_hf_config.model_type = "llama" + + mock_model_config = MagicMock() + mock_model_config.hf_config = mock_hf_config + + mock_vllm_config = MagicMock() + mock_vllm_config.model_config = mock_model_config + mock_vllm_config.parallel_config = MagicMock(tensor_parallel_size=2) + mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4) + mock_vllm_config.model_config.max_model_len = 2048 + + mocker.patch('vllm_ascend.ops.fused_moe.get_current_vllm_config', + return_value=mock_vllm_config) + mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config', + return_value=mock_vllm_config) + + @pytest.fixture def mock_dist_env(mocker: MockerFixture): mock_moe_comm_method = MagicMock() @@ -74,7 +95,7 @@ def mock_dist_env(mocker: MockerFixture): mock_forward_context_obj = MagicMock( moe_comm_method=mock_moe_comm_method, - moe_comm_method_name="mc2commimpl", + moe_comm_type=MoECommType.MC2, max_tokens_across_dp=10, dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]), mc2_mask=torch.zeros(16, dtype=torch.bool), @@ -104,12 +125,6 @@ def mock_dist_env(mocker: MockerFixture): return_value=mock_forward_context_obj), \ patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context', return_value=mock_forward_context_obj), \ - patch('vllm_ascend.ops.fused_moe.get_current_vllm_config', - return_value=MagicMock( - parallel_config=MagicMock(tensor_parallel_size=2), - scheduler_config=MagicMock(max_num_seqs=4), - model_config=MagicMock(max_model_len=2048) - )), \ patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \ patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context', return_value=mock_forward_context_obj), \ @@ -501,7 +516,7 @@ class TestUnifiedApplyMLP(TestBase): mock_get_forward_context): mock_forward_context = MagicMock() - mock_forward_context.moe_comm_method_name = "mc2commimpl" + mock_forward_context.moe_comm_type = MoECommType.MC2 mock_get_forward_context.return_value = mock_forward_context mock_is_310p.return_value = False diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py index 4bd0e10..97aea93 100644 --- a/tests/ut/ops/test_moe_comm_method.py +++ b/tests/ut/ops/test_moe_comm_method.py @@ -24,6 +24,7 @@ class TestMoECommMethod(TestBase): self.moe_config.dp_group = MagicMock() self.moe_config.num_global_redundant_experts = 0 + @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") @patch( "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather" @@ -31,7 +32,11 @@ class TestMoECommMethod(TestBase): @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather") def test_all_gather_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize, - mock_get_forward_context): + mock_get_forward_context, + mock_get_current_vllm_config): + # Mock vLLM config + mock_get_current_vllm_config.return_value = MagicMock() + # Mock forward context mock_context = MagicMock() mock_context.moe_comm_method = "all_gather" @@ -64,13 +69,18 @@ class TestMoECommMethod(TestBase): comm_impl.finalize(h_out, reduce_results=True) mock_pf_instance.finalize.assert_called_once_with(h_out, True) + @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") @patch( "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithMC2" ) @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithMC2") def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize, - mock_get_forward_context): + mock_get_forward_context, + mock_get_current_vllm_config): + # Mock vLLM config + mock_get_current_vllm_config.return_value = MagicMock() + # Mock forward context mock_context = MagicMock() mock_context.moe_comm_method = "mc2" @@ -104,6 +114,7 @@ class TestMoECommMethod(TestBase): comm_impl.finalize(h_out, reduce_results=True) mock_pf_instance.finalize.assert_called_once_with(h_out, True) + @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") @patch( "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAll2All" @@ -111,7 +122,11 @@ class TestMoECommMethod(TestBase): @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAll2AllV") def test_alltoall_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize, - mock_get_forward_context): + mock_get_forward_context, + mock_get_current_vllm_config): + # Mock vLLM config + mock_get_current_vllm_config.return_value = MagicMock() + # Mock forward context mock_context = MagicMock() mock_context.moe_comm_method = "alltoall" @@ -140,6 +155,7 @@ class TestMoECommMethod(TestBase): mock_pf_instance.prepare.assert_called_once_with( hidden_states, router_logits, False, False, False, None) + @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") @patch( "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather" @@ -148,7 +164,11 @@ class TestMoECommMethod(TestBase): @patch("vllm_ascend.ops.moe.moe_comm_method.unified_apply_mlp") def test_fused_experts_method(self, mock_unified_apply_mlp, mock_token_dispatcher, mock_prepare_finalize, - mock_get_forward_context): + mock_get_forward_context, + mock_get_current_vllm_config): + # Mock vLLM config + mock_get_current_vllm_config.return_value = MagicMock() + # Mock forward context mock_context = MagicMock() mock_context.moe_comm_method = "all_gather" diff --git a/tests/ut/quantization/test_w4a8_dynamic.py b/tests/ut/quantization/test_w4a8_dynamic.py index 70256af..d12bbe1 100644 --- a/tests/ut/quantization/test_w4a8_dynamic.py +++ b/tests/ut/quantization/test_w4a8_dynamic.py @@ -48,18 +48,27 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase): output_size = 56 group_size = 2 + @patch('vllm_ascend.quantization.w4a8_dynamic.get_ascend_config') @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config') @patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group') @patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group') @patch('torch.distributed.get_rank', return_value=0) def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ep_group, - get_current_vllm_config): + get_current_vllm_config, mock_get_ascend_config): + # Mock ascend config + mock_ascend_config = Mock() + mock_ascend_config.dynamic_eplb = False + mock_get_ascend_config.return_value = mock_ascend_config + mock_vllm_config = Mock() mock_vllm_config.quant_config = Mock(quant_description={ "group_size": self.group_size, "version": "0.0.0" }) mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True) + mock_vllm_config.scheduler_config = Mock(max_num_batched_tokens=2048, + max_model_len=2048, + enable_chunked_prefill=False) get_current_vllm_config.return_value = mock_vllm_config self.quant_method = AscendW4A8DynamicFusedMoEMethod() diff --git a/tests/ut/worker/test_model_runner_v1.py b/tests/ut/worker/test_model_runner_v1.py index 9c116de..70b7c7d 100644 --- a/tests/ut/worker/test_model_runner_v1.py +++ b/tests/ut/worker/test_model_runner_v1.py @@ -15,6 +15,7 @@ from unittest.mock import MagicMock, patch import pytest +from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.utils import AscendSocVersion from vllm_ascend.worker.model_runner_v1 import NPUModelRunner @@ -24,21 +25,21 @@ from vllm_ascend.worker.model_runner_v1 import NPUModelRunner "soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method", [ # Case 1: Expert parallel is disabled, should always be 'allgather' - (AscendSocVersion.A2, False, 8, 100, 256, None, "allgather"), - (AscendSocVersion.A3, False, 16, 500, 256, None, "allgather"), + (AscendSocVersion.A2, False, 8, 100, 256, None, MoECommType.ALLGATHER), + (AscendSocVersion.A3, False, 16, 500, 256, None, MoECommType.ALLGATHER), # Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2 - (AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", "alltoall"), - (AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", "alltoall"), - (AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", "mc2"), # meets mc2 condition + (AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", MoECommType.ALLTOALL), + (AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", MoECommType.ALLTOALL), + (AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", MoECommType.MC2), # meets mc2 condition # Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather - (AscendSocVersion.A2, True, 8, 100, 256, None, "allgather"), - (AscendSocVersion.A2, True, 16, 257, 256, None, "allgather"), + (AscendSocVersion.A2, True, 8, 100, 256, None, MoECommType.ALLGATHER), + (AscendSocVersion.A2, True, 16, 257, 256, None, MoECommType.ALLGATHER), # Case 4: A3 SOC - (AscendSocVersion.A3, True, 8, 100, 256, None, "mc2"), - (AscendSocVersion.A3, True, 8, 257, 256, None, "alltoall"), + (AscendSocVersion.A3, True, 8, 100, 256, None, MoECommType.MC2), + (AscendSocVersion.A3, True, 8, 257, 256, None, MoECommType.ALLTOALL), ]) # yapf: enable def test_select_moe_comm_method(soc_version, enable_expert_parallel, diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index a8cbf83..e5ce07f 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -22,6 +22,13 @@ class FusedMoEState(Enum): All2AllSeq = 5 +class MoECommType(Enum): + ALLGATHER = 0 + MC2 = 1 + ALLTOALL = 2 + NAIVE_MULTICAST = 3 + + # TODO(zzzzwwjj): add soc_version to choose branch def _get_fused_moe_state(ep_size: int, with_prefill: bool, is_deepseek_v3_r1: bool): @@ -52,7 +59,7 @@ def set_ascend_forward_context( with_prefill: bool = True, in_profile_run: bool = False, reserved_mc2_mask: Optional[torch.Tensor] = None, - moe_comm_method: str = "", + moe_comm_type: Optional[MoECommType] = None, num_actual_tokens: Optional[int] = None, aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, batch_descriptor: Optional[BatchDescriptor] = None, @@ -72,7 +79,11 @@ def set_ascend_forward_context( batch_descriptor=batch_descriptor, ): forward_context = get_forward_context() - forward_context.moe_comm_method_name = moe_comm_method + "commimpl" + + from vllm_ascend.ops.moe.moe_comm_method import get_moe_comm_method + forward_context.moe_comm_type = moe_comm_type + forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type) + forward_context.with_prefill = with_prefill tp_world_size = get_tensor_model_parallel_world_size() ep_size = (get_ep_group().world_size if diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index ff301bd..57beae2 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -23,106 +23,23 @@ from vllm.config import CompilationLevel, get_current_vllm_config from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group, tensor_model_parallel_all_reduce) from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.fused_moe.config import \ - FusedMoEParallelConfig # isort: skip from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map) from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map, determine_default_log2phy_map) from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.ops.moe.experts_selector import select_experts -from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl, - AlltoAllCommImpl, MC2CommImpl, - NaiveMulticastCommImpl) +from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__ -def fused_experts_moge( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - moe_parallel_config: FusedMoEParallelConfig, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - global_num_experts: int, - expert_map: torch.Tensor = None, - apply_router_weight_on_input: bool = False, -) -> torch.Tensor: - """ - - Args: - hidden_states: Hidden states of shape (num_tokens, hidden_size). - w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size). - w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size). - topk_weights: Routing weights of shape (num_tokens, top_k). - topk_ids: Selected expert IDs of shape (num_tokens, top_k). - top_k: Number of experts to select. - expert_map: Expert mapping of shape (num_experts,). - - Returns: - hidden_states: Hidden states after routing. - """ - ep_size = moe_parallel_config.ep_size - local_num_experts = global_num_experts // ep_size - local_num_group = top_k // ep_size - - bsz, _ = hidden_states.shape - flatten_topk_ids = topk_ids.view(-1) - sorted_topk_ids = torch.argsort(flatten_topk_ids.float()) - sorted_topk_ids = sorted_topk_ids.to(torch.int32) - sorted_hidden_states = hidden_states.index_select( - 0, sorted_topk_ids // local_num_group) - - experts_id = torch.arange(0, - local_num_experts, - dtype=topk_ids.dtype, - device=topk_ids.device) - num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to( - torch.float32).sum(0) - topk_scales = topk_weights.view(-1).index_select( - 0, sorted_topk_ids).unsqueeze(-1) - group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64) - - gate_up_out = torch_npu.npu_grouped_matmul( - x=[sorted_hidden_states], - weight=[w1], - split_item=2, - group_list_type=0, - group_type=0, - group_list=group_list, - )[0] - - if is_310p(): - gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to( - torch.float16) - else: - gate_up_out = torch_npu.npu_swiglu(gate_up_out) - gate_up_out *= topk_scales - - down_out_list = torch_npu.npu_grouped_matmul( - x=[gate_up_out], - weight=[w2], - split_item=2, - group_list_type=0, - group_type=0, - group_list=group_list, - )[0] - - unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32) - unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids) - final_hidden_states = unsorted_hidden_states.reshape( - bsz, top_k // ep_size, -1).sum(1) - - return final_hidden_states - - def unquantized_fused_moe_init_func(self, *args, **kwargs): original_unquantized_fused_moe_init_func(self, *args, **kwargs) @@ -178,20 +95,6 @@ def forward_oot( e_score_correction_bias=e_score_correction_bias, global_num_experts=global_num_experts) - if topk_ids.shape[1] < top_k or is_310p(): - assert global_num_experts is not None - return fused_experts_moge( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - moe_parallel_config=self.moe.moe_parallel_config, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input) - moe_comm_method = get_forward_context().moe_comm_method return moe_comm_method.fused_experts(hidden_states=x, w1=layer.w13_weight, @@ -277,13 +180,7 @@ class AscendFusedMoE(FusedMoE): if self.dynamic_eplb: self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64) - for method in { - AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl, - NaiveMulticastCommImpl - }: - setattr( - self, method.__name__.lower(), - method(moe_config=self.moe_config)) # type: ignore[abstract] + setup_moe_comm_method(self.moe_config) def update_expert_map(self, new_expert_map): self.expert_map = new_expert_map @@ -307,8 +204,8 @@ class AscendFusedMoE(FusedMoE): outputs since each rank only has partial outputs. """ forward_context = get_forward_context() - moe_comm_method_name = forward_context.moe_comm_method_name - if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}: + moe_comm_type = forward_context.moe_comm_type + if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}: return final_hidden_states else: return tensor_model_parallel_all_reduce(final_hidden_states) @@ -318,10 +215,6 @@ class AscendFusedMoE(FusedMoE): assert self.quant_method is not None forward_context = get_forward_context() - moe_comm_method_name = forward_context.moe_comm_method_name - - forward_context.moe_comm_method = getattr(self, moe_comm_method_name) - hidden_states, router_logits = forward_context.moe_comm_method.prepare( hidden_states=hidden_states, router_logits=router_logits) @@ -449,8 +342,8 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE): # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel` forward_context = get_forward_context() - moe_comm_method_name = forward_context.moe_comm_method_name - if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}: + moe_comm_type = forward_context.moe_comm_type + if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}: shared_out = tensor_model_parallel_all_reduce(shared_out) _, fused_out = AscendFusedMoE.forward( diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 0c1526d..4fd4b1b 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -41,9 +41,7 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map, determine_default_log2phy_map) from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.ops.moe.experts_selector import select_experts -from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl, - AlltoAllCommImpl, MC2CommImpl, - NaiveMulticastCommImpl) +from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method from vllm_ascend.ops.sequence_parallel import MetadataForPadding from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, get_all_reduce_merge_state, @@ -339,13 +337,7 @@ class AscendFusedMoE(FusedMoE): self.moe_config.mc2_group = get_mc2_group() self.moe_config.num_global_redundant_experts = self.global_redundant_expert_num - for method in { - AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl, - NaiveMulticastCommImpl - }: - setattr( - self, method.__name__.lower(), - method(moe_config=self.moe_config)) # type: ignore[abstract] + setup_moe_comm_method(self.moe_config) def update_expert_map(self, new_expert_map): self.expert_map = new_expert_map @@ -360,22 +352,6 @@ class AscendFusedMoE(FusedMoE): if self.moe_load is not None: self.moe_load.zero_() - def naive_multicast(self, x: torch.Tensor, - cu_tokens_across_dp_cpu: torch.Tensor): - assert (len(x.shape) == 2) - buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), - device=x.device, - dtype=x.dtype) - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ - self.dp_rank - 1] - end = cu_tokens_across_dp_cpu[self.dp_rank] - buffer[start:end, :].copy_(x) - for idx in range(self.dp_size): - start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] - end = cu_tokens_across_dp_cpu[idx] - get_dp_group().broadcast(buffer[start:end, :], idx) - return buffer - def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor, @@ -412,9 +388,6 @@ class AscendFusedMoE(FusedMoE): mc2_mask = chunk_mc2_mask[tp_rank] replace_allreduce = True - moe_comm_method_name = forward_context.moe_comm_method_name - forward_context.moe_comm_method = getattr(self, moe_comm_method_name) - hidden_states, router_logits = forward_context.moe_comm_method.prepare( hidden_states=hidden_states, router_logits=router_logits, diff --git a/vllm_ascend/ops/moe/moe_comm_method.py b/vllm_ascend/ops/moe/moe_comm_method.py index e4082ba..555189e 100644 --- a/vllm_ascend/ops/moe/moe_comm_method.py +++ b/vllm_ascend/ops/moe/moe_comm_method.py @@ -13,14 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # This file is a part of the vllm-ascend project. +from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any, Dict, Optional import torch +from vllm.config import get_current_vllm_config from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe import FusedMoEConfig +from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import ( FusedMoEPrepareAndFinalizeWithAll2All, FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2, @@ -28,13 +31,31 @@ from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import ( from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp from vllm_ascend.ops.moe.token_dispatcher import (TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather, - TokenDispatcherWithMC2) + TokenDispatcherWithMC2, + TokenDispatcherWithMoge) + +_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {} + + +def get_moe_comm_method( + moe_comm_type: Optional[MoECommType]) -> Optional[MoECommMethod]: + return _MoECommMethods.get(moe_comm_type) + + +def setup_moe_comm_method(moe_config): + _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config) + _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config) + _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config) + _MoECommMethods[MoECommType.NAIVE_MULTICAST] = NaiveMulticastCommImpl( + moe_config) class MoECommMethod(ABC): """Base class for MoE communication methods.""" def __init__(self, moe_config: FusedMoEConfig): + self.model_type = get_current_vllm_config( + ).model_config.hf_config.model_type self.moe_config = moe_config self.mc2_mask = None @@ -113,8 +134,8 @@ class MoECommMethod(ABC): apply_router_weight_on_input=apply_router_weight_on_input, with_quant=use_int8_w8a8 or use_int4_w4a8) - permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type = \ - results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"] + permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales = \ + results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales") mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states, w1=w1, @@ -126,6 +147,7 @@ class MoECommMethod(ABC): group_list_type=group_list_type, w1_scale_bias=w1_scale_bias, w2_scale_bias=w2_scale_bias, + topk_scales=topk_scales, with_quant=use_int8_w8a8 or use_int4_w4a8, fusion=use_int8_w8a8, @@ -170,94 +192,21 @@ class AllGatherCommImpl(MoECommMethod): """ def _get_token_dispatcher(self): - return TokenDispatcherWithAllGather( - top_k=self.moe_config.experts_per_token, - num_experts=self.moe_config.num_experts, - num_local_experts=self.moe_config.num_local_experts) + if self.model_type == "PanguProMoE": + return TokenDispatcherWithMoge( + top_k=self.moe_config.experts_per_token, + num_experts=self.moe_config.num_experts, + num_local_experts=self.moe_config.num_local_experts) + else: + return TokenDispatcherWithAllGather( + top_k=self.moe_config.experts_per_token, + num_experts=self.moe_config.num_experts, + num_local_experts=self.moe_config.num_local_experts) def _get_fused_moe_prepare_finalize(self): return FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config) -class NativeAllGatherCommImpl(AllGatherCommImpl): - """This implementation should be compatible with all scenarios. - - Note that this implementation purely consists of native PyTorch ops - and does not use any NPU-specific ops. So the performance may not be optimal. - But it is a good fallback for scenarios where NPU-specific ops are not available. - """ - - def permute( - self, - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - expert_map: torch.Tensor, - num_experts: int, - apply_a8_quantization: bool, - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]: - num_tokens = hidden_states.shape[0] - - # Generate token indices and flatten - token_indices = torch.arange(num_tokens, - device=hidden_states.device, - dtype=torch.int64) - token_indices = (token_indices.unsqueeze(1).expand( - -1, self.moe_config.experts_per_token).reshape(-1)) - - # Flatten token-to-expert mappings and map to local experts - weights_flat = topk_weights.view(-1) - experts_flat = topk_ids.view(-1) - local_experts_flat = (expert_map[experts_flat] - if expert_map is not None else experts_flat) - - # Filter valid token-expert pairs - mask = local_experts_flat != -1 - # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...] - # So we need to filter out invalid tokens by zeroing their weights. - # This is a workaround and should be removed after the issue is fixed - filtered_weights = torch.where(mask, weights_flat, - torch.zeros_like(weights_flat)).to( - topk_weights.dtype) - filtered_experts = torch.where( - mask, - local_experts_flat, - torch.full_like(local_experts_flat, num_experts), - ).to(topk_ids.dtype) - - # Sort by local expert IDs - sort_indices = torch.argsort(filtered_experts.view(torch.float32)) - self.sorted_token_indices = token_indices[sort_indices] - self.sorted_weights = filtered_weights[sort_indices] - - # Compute token counts with minlength of num_experts - # This is equivalent to but faster than: - # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1] - token_counts = torch.zeros(num_experts + 1, - device=hidden_states.device, - dtype=torch.int64) - ones = torch.ones_like(filtered_experts, dtype=torch.int64) - token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones) - expert_tokens = token_counts[:num_experts] - - # Rearrange hidden_states - permuted_hidden_states = hidden_states[self.sorted_token_indices] - - group_list_type = 1 # `count` mode - - return permuted_hidden_states, expert_tokens, None, group_list_type - - def unpermute(self, mlp_output: torch.Tensor, - hidden_states: torch.Tensor) -> None: - mlp_output = mlp_output * self.sorted_weights.unsqueeze(1) - - final_hidden_states = torch.zeros_like(hidden_states) - final_hidden_states.index_add_(0, self.sorted_token_indices, - mlp_output) - - hidden_states[:] = final_hidden_states - - class MC2CommImpl(MoECommMethod): """This implementation is for the scenarios listed below: 1. `enable_expert_parallel=True`. diff --git a/vllm_ascend/ops/moe/moe_mlp.py b/vllm_ascend/ops/moe/moe_mlp.py index b1567b0..3cc8b95 100644 --- a/vllm_ascend/ops/moe/moe_mlp.py +++ b/vllm_ascend/ops/moe/moe_mlp.py @@ -21,6 +21,7 @@ import torch_npu from torch.nn.functional import pad from vllm.forward_context import get_forward_context +from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.utils import dispose_tensor, is_310p @@ -76,7 +77,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor, bias1, bias2 = None, None _output_dtype = w2_scale.dtype - is_mc2 = get_forward_context().moe_comm_method_name == "mc2commimpl" + is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2 if w1_scale_bias is None and is_mc2: if w1_scale.dtype != torch.float32: w1_scale = w1_scale.to(torch.float32) diff --git a/vllm_ascend/ops/moe/token_dispatcher.py b/vllm_ascend/ops/moe/token_dispatcher.py index c57cc1c..b6f908e 100644 --- a/vllm_ascend/ops/moe/token_dispatcher.py +++ b/vllm_ascend/ops/moe/token_dispatcher.py @@ -377,14 +377,13 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): # mypy: disable-error-code="override" -class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher): +class TokenDispatcherWithMoge(MoETokenDispatcher): def __init__(self, **kwargs): super().__init__(**kwargs) self.apply_router_weight_on_input = False - self.local_ep = 1 - self.local_num_experts = self.num_experts // self.local_ep - self.local_num_group = self.top_k // self.local_ep + self.local_num_experts = self.num_experts // self.ep_size + self.local_num_group = self.top_k // self.ep_size self.bsz = None def token_dispatch(self, @@ -401,17 +400,6 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher): mc2_mask: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, with_quant: bool = False): - self.apply_router_weight_on_input = apply_router_weight_on_input - if self.apply_router_weight_on_input: - assert (topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" - _, topk = topk_weights.shape - assert ( - topk == 1 - ), "Only support topk=1 when `apply_router_weight_on_input` is True" - hidden_states = hidden_states * \ - topk_weights.to(hidden_states.dtype) - self.bsz, _ = hidden_states.shape flatten_topk_ids = topk_ids.view(-1) self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float()) @@ -445,7 +433,7 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher): unsorted_hidden_states = hidden_states.index_select( 0, unsorted_topk_ids) final_hidden_states = unsorted_hidden_states.reshape( - self.bsz, self.top_k // self.local_ep, -1).sum(1) + self.bsz, self.top_k // self.ep_size, -1).sum(1) return final_hidden_states diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 9184bde..f993e3a 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -117,11 +117,11 @@ class EagleProposer(Proposer): skip_attn: bool = False, num_reqs: int = 0, num_tokens_across_dp: Optional[torch.Tensor] = None): - moe_comm_method = self.runner._select_moe_comm_method( + moe_comm_type = self.runner._select_moe_comm_method( num_tokens, with_prefill) with set_ascend_forward_context(None, self.vllm_config, - moe_comm_method=moe_comm_method, + moe_comm_type=moe_comm_type, num_tokens=num_tokens): self.model( input_ids=self.input_ids[:num_tokens], @@ -454,7 +454,7 @@ class EagleProposer(Proposer): with_prefill = attn_metadata.attn_state not in [ AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding ] - moe_comm_method = self.runner._select_moe_comm_method( + moe_comm_type = self.runner._select_moe_comm_method( num_input_tokens, with_prefill) # copy inputs to buffer for cudagraph @@ -463,7 +463,7 @@ class EagleProposer(Proposer): attn_metadata.block_tables = block_table.to(device) with set_ascend_forward_context(attn_metadata, self.vllm_config, - moe_comm_method=moe_comm_method, + moe_comm_type=moe_comm_type, num_tokens=num_input_tokens): last_hidden_states, hidden_states = self.model( input_ids=self.input_ids[:num_input_tokens], @@ -495,7 +495,7 @@ class EagleProposer(Proposer): else: input_batch_size = batch_size - moe_comm_method = self.runner._select_moe_comm_method( + moe_comm_type = self.runner._select_moe_comm_method( input_batch_size, False) attn_metadata.num_actual_tokens = batch_size @@ -568,7 +568,7 @@ class EagleProposer(Proposer): # Run the model. with set_ascend_forward_context(attn_metadata, self.vllm_config, - moe_comm_method=moe_comm_method, + moe_comm_type=moe_comm_type, num_tokens=input_batch_size): last_hidden_states, hidden_states = self.model( diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 800b57f..5694c23 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -113,7 +113,7 @@ class MtpProposer(Proposer): _) = self.runner._sync_metadata_across_dp(num_tokens, with_prefill, False) - moe_comm_method = self.runner._select_moe_comm_method( + moe_comm_type = self.runner._select_moe_comm_method( num_tokens, with_prefill) is_running_torchair = self.torchair_graph_enabled and \ @@ -146,7 +146,7 @@ class MtpProposer(Proposer): with_prefill=with_prefill, num_tokens_across_dp=num_tokens_across_dp, reserved_mc2_mask=self.runner.reserved_mc2_mask, - moe_comm_method=moe_comm_method, + moe_comm_type=moe_comm_type, in_profile_run=self.runner.in_profile_run, num_actual_tokens=0): if is_running_torchair: @@ -425,7 +425,7 @@ class MtpProposer(Proposer): num_tokens_across_dp = self.runner.num_tokens_across_dp with_prefill = self.runner.with_prefill - moe_comm_method = self.runner._select_moe_comm_method( + moe_comm_type = self.runner._select_moe_comm_method( num_input_tokens, with_prefill) batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, uniform_decode=False) @@ -440,7 +440,7 @@ class MtpProposer(Proposer): with_prefill=with_prefill, num_tokens_across_dp=num_tokens_across_dp, reserved_mc2_mask=self.runner.reserved_mc2_mask, - moe_comm_method=moe_comm_method, + moe_comm_type=moe_comm_type, aclgraph_runtime_mode=aclgraph_runtime_mode, in_profile_run=self.runner.in_profile_run, num_actual_tokens=num_tokens): diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index e096fd6..a7cdd26 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -94,7 +94,8 @@ from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache, scatter_mm_placeholders) from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ascend_forward_context import set_ascend_forward_context +from vllm_ascend.ascend_forward_context import (MoECommType, + set_ascend_forward_context) from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata @@ -1860,7 +1861,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): ) def _select_moe_comm_method(self, num_tokens: int, - with_prefill: bool) -> str: + with_prefill: bool) -> MoECommType: """1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all are designed for expert parallelism. 2. If expert parallel is enabled, we need to consider the soc version and the @@ -1881,36 +1882,44 @@ class NPUModelRunner(LoRAModelRunnerMixin): ValueError: If the soc version is unsupported. Returns: - str: The selected MoE communication method, either "allgather", "mc2", or "alltoall". + MoECommType: The selected MoE communication method. """ soc_version = get_ascend_soc_version() quant_type = getattr(self.vllm_config.model_config.hf_config, 'moe_quantize', None) + model_type = self.vllm_config.model_config.hf_config.model_type if not self.parallel_config.enable_expert_parallel: - moe_comm_method = "allgather" + moe_comm_type = MoECommType.ALLGATHER elif soc_version in {AscendSocVersion.A2}: - if num_tokens <= self.mc2_tokens_capacity and self.parallel_config.world_size_across_dp >= 16: - moe_comm_method = "mc2" + if (num_tokens <= self.mc2_tokens_capacity + and self.parallel_config.world_size_across_dp >= 16): + moe_comm_type = MoECommType.MC2 else: + # Currently, w4a8_dynamic does not support allgatherep if quant_type == "w4a8_dynamic": - moe_comm_method = "alltoall" + moe_comm_type = MoECommType.ALLTOALL else: - moe_comm_method = "allgather" + moe_comm_type = MoECommType.ALLGATHER elif soc_version in {AscendSocVersion.A3}: - moe_comm_method = "mc2" if num_tokens <= self.mc2_tokens_capacity else "alltoall" + moe_comm_type = (MoECommType.MC2 + if num_tokens <= self.mc2_tokens_capacity else + MoECommType.ALLTOALL) else: raise ValueError(f"Unsupported soc_version: {soc_version}") - if moe_comm_method == "allgather" and with_prefill: - moe_comm_method = "naivemulticast" + if moe_comm_type == MoECommType.ALLGATHER and with_prefill: + moe_comm_type = MoECommType.NAIVE_MULTICAST + + # PanguProMoE only supports allgather + if model_type == "PanguProMoE": + moe_comm_type = MoECommType.ALLGATHER if is_global_first_rank(): logger.debug(f"num_tokens: {num_tokens}, " - f"moe_comm_method: {moe_comm_method}") - - return moe_comm_method + f"moe_comm_type: {moe_comm_type}") + return moe_comm_type @torch.inference_mode() def execute_model( @@ -1942,8 +1951,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): if self.dynamic_eplb: self.eplb_updator.take_update_info_from_eplb_process() - moe_comm_method = self._select_moe_comm_method(num_input_tokens, - self.with_prefill) + moe_comm_type = self._select_moe_comm_method(num_input_tokens, + self.with_prefill) uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( scheduler_output.total_num_scheduled_tokens @@ -1962,7 +1971,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_tokens_across_dp=num_tokens_across_dp, with_prefill=self.with_prefill, reserved_mc2_mask=self.reserved_mc2_mask, - moe_comm_method=moe_comm_method, + moe_comm_type=moe_comm_type, aclgraph_runtime_mode=aclgraph_runtime_mode, batch_descriptor=batch_descriptor, num_actual_tokens=scheduler_output. @@ -2351,8 +2360,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): (num_tokens, num_tokens_across_dp, with_prefill, _) = self._sync_metadata_across_dp(num_tokens, with_prefill, False) - moe_comm_method = self._select_moe_comm_method(num_tokens, - with_prefill) + moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill) # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.seperate_routine(). This means that we are using @@ -2472,7 +2480,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): with_prefill=with_prefill, in_profile_run=self.in_profile_run, reserved_mc2_mask=self.reserved_mc2_mask, - moe_comm_method=moe_comm_method, + moe_comm_type=moe_comm_type, num_actual_tokens=0, aclgraph_runtime_mode=aclgraph_runtime_mode, batch_descriptor=batch_descriptor,