diff --git a/tests/ut/models/test_deepseek_v2.py b/tests/ut/models/test_deepseek_v2.py index c6eeb97f..76923930 100644 --- a/tests/ut/models/test_deepseek_v2.py +++ b/tests/ut/models/test_deepseek_v2.py @@ -23,8 +23,7 @@ from vllm.distributed.parallel_state import GroupCoordinator from vllm_ascend.models.deepseek_v2 import ( CustomDeepseekV2MergedReplicatedLinear, CustomDeepseekV2MLAAttention, - CustomDeepseekV2MLP, CustomDeepseekV2MoE, - CustomDeepseekV2RowParallelLinear, + CustomDeepseekV2MLP, CustomDeepseekV2RowParallelLinear, CustomDeepseekV2RowParallelLinearReplaceAllreduce, CustomDeepseekV2SiluAndMul, LogitsProcessor, ParallelLMHead) @@ -213,22 +212,6 @@ def test_custom_deepseek_v2_mlp(mock_distributed, base_config): quant_config=None) -def test_custom_deepseek_v2_moe(mock_distributed, base_config, - mock_forward_context): - base_config.n_shared_experts = 1 - moe = CustomDeepseekV2MoE(config=base_config, - quant_config=None, - prefix="mlp") - assert moe.top_k == 2 - - x = torch.randn(2, 4, 128) - attn_metadata = Mock(num_prefills=1) - with patch("vllm_ascend.ops.fused_moe.AscendFusedMoE.__call__", - return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))): - output = moe(x, attn_metadata) - assert output.shape == (2, 4, 128) - - @patch("torch_npu.npu_rms_norm") def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed, base_config): diff --git a/tests/ut/ops/test_ascend_forwad_context.py b/tests/ut/ops/test_ascend_forwad_context.py deleted file mode 100644 index 17e3c6f3..00000000 --- a/tests/ut/ops/test_ascend_forwad_context.py +++ /dev/null @@ -1,22 +0,0 @@ -import os -import unittest -from unittest import mock - -from vllm_ascend.ascend_forward_context import get_dispatcher_name - - -class TestGetDispatcherName(unittest.TestCase): - - def test_get_dispatcher_name(self): - result = get_dispatcher_name(1, False) - assert result == "TokenDispatcherWithAllGather" - result = get_dispatcher_name(4, False) - assert result == "TokenDispatcherWithAll2AllV" - result = get_dispatcher_name(16, True) - assert result == "TokenDispatcherWithAll2AllV" - result = get_dispatcher_name(16, False) - assert result == "TokenDispatcherWithMC2" - with mock.patch.dict(os.environ, - {"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP": "1"}): - result = get_dispatcher_name(16, False) - assert result == "TokenDispatcherWithAllGather" diff --git a/tests/ut/ops/test_fused_moe_prepare_and_finalize.py b/tests/ut/ops/test_fused_moe_prepare_and_finalize.py index f0c5ff82..a4a61a1a 100644 --- a/tests/ut/ops/test_fused_moe_prepare_and_finalize.py +++ b/tests/ut/ops/test_fused_moe_prepare_and_finalize.py @@ -6,7 +6,8 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import ( FusedMoEPrepareAndFinalizeWithAll2All, - FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2) + FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2, + FusedMoEPrepareAndFinalizeWithNaiveMulticast) class TestFusedMoEPrepareAndFinalize(unittest.TestCase): @@ -216,3 +217,68 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): mock_tp_all_reduce.return_value = result result_with_tp = layer.finalize(h_out, reduce_results=True) self.assertEqual(result_with_tp.shape[0], 3) + + @patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group") + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce" + ) + @patch( + 
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context" + ) + def test_naive_multicast_prepare_finalize(self, mock_get_forward_context, + mock_tp_all_reduce, + mock_get_dp_group): + # Mock forward context with DP metadata + mock_context = MagicMock() + mock_context.dp_metadata.cu_tokens_across_dp_cpu = torch.tensor( + [2, 5, 7]) + mock_get_forward_context.return_value = mock_context + + # Setup DP group mock + mock_dp_group = MagicMock() + mock_dp_group.broadcast = MagicMock() + mock_dp_group.all_reduce = MagicMock() + mock_get_dp_group.return_value = mock_dp_group + + # Mock all_reduce to just return input (simulate sum) + def mock_all_reduce(tensor): + return tensor * 2 + + mock_dp_group.all_reduce.side_effect = mock_all_reduce + + # Setup config + self.moe_config.dp_size = 3 + self.moe_config.dp_rank = 1 + self.moe_config.tp_size = 1 + self.moe_config.ep_size = 1 + + layer = FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config) + + # Local inputs + hidden_states = torch.randn(3, 8) + router_logits = torch.randn(3, 2) + + # Mock gate for router logits recomputation + mock_gate = MagicMock() + mock_gate.return_value = (torch.randn(7, 2), None) + + # Run prepare + h_out, r_out, _ = layer.prepare(hidden_states, + router_logits, + rm_router_logits=False, + gate=mock_gate) + + # Should be global tensor: [7, 8] and [7, 2] + self.assertEqual(h_out.shape, (7, 8)) + self.assertEqual(r_out.shape, (7, 2)) + + # Run finalize + result = layer.finalize(h_out, reduce_results=False) + + # Should slice back to local: [3, 8] + self.assertEqual(result.shape, (3, 8)) + + # Test with reduce_results=True and TP/EP > 1 + mock_tp_all_reduce.return_value = result + result_with_tp = layer.finalize(h_out, reduce_results=True) + self.assertEqual(result_with_tp.shape, (3, 8)) diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py index 3e9351af..bd7ee587 100644 --- a/tests/ut/ops/test_fused_ops.py +++ b/tests/ut/ops/test_fused_ops.py @@ -22,10 +22,7 @@ import torch_npu from pytest_mock import MockerFixture from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase -import vllm_ascend.ops.moe.token_dispatcher as token_dispatcher_module from tests.ut.base import TestBase -from vllm_ascend.ascend_forward_context import (FusedMoEState, - _get_fused_moe_state) from vllm_ascend.ops.fused_moe import (AscendFusedMoE, AscendUnquantizedFusedMoEMethod) from vllm_ascend.ops.moe.experts_selector import select_experts @@ -60,68 +57,24 @@ def mock_npu_format_cast(weight_data, format): @pytest.fixture def mock_dist_env(mocker: MockerFixture): - mock_setup_token_dispatchers = MagicMock() - mock_token_dispatcher_with_allgather = MagicMock() - mock_token_dispatcher_with_all2allv = MagicMock() - mock_token_dispatcher_with_mc2 = MagicMock() + mock_moe_comm_method = MagicMock() - mock_dispatch_result_allgather = { - "hidden_states": torch.randn(16, 2), - "group_list": torch.tensor([8, 16], dtype=torch.int64), - "group_list_type": 0, - } - mock_combine_result_allgather = torch.randn(16, 2) + def mock_prepare(hidden_states, router_logits, **kwargs): + return hidden_states, router_logits - mock_token_dispatcher_with_allgather.token_dispatch.return_value = mock_dispatch_result_allgather - mock_token_dispatcher_with_allgather.token_combine.return_value = mock_combine_result_allgather + mock_moe_comm_method.prepare.side_effect = mock_prepare - mock_dispatch_result_all2allv = { - "hidden_states": torch.randn(16, 2), - "group_list": torch.tensor([4, 8, 12, 16], dtype=torch.int64), - 
"group_list_type": 1, - "dynamic_scale": None, - } - mock_combine_result_all2allv = torch.randn(16, 2) - mock_token_dispatcher_with_all2allv.token_dispatch.return_value = mock_dispatch_result_all2allv - mock_token_dispatcher_with_all2allv.token_combine.return_value = mock_combine_result_all2allv + mock_fused_experts_result = torch.randn(16, 2) + mock_moe_comm_method.fused_experts.return_value = mock_fused_experts_result - mock_dispatch_result_mc2 = { - "hidden_states": torch.randn(16, 2), - "group_list": torch.tensor([5, 10, 15, 16], dtype=torch.int64), - "group_list_type": 1, - "dynamic_scale": None, - "assist_info_for_combine": torch.randn(16, 2), - "ep_recv_counts": torch.tensor([4, 4, 4, 4], dtype=torch.int32), - } - mock_combine_result_mc2 = torch.randn(16, 2) - mock_token_dispatcher_with_mc2.token_dispatch.return_value = mock_dispatch_result_mc2 - mock_token_dispatcher_with_mc2.token_combine.return_value = mock_combine_result_mc2 + def mock_finalize(hidden_states, **kwargs): + return hidden_states - captured_dispatchers = {} - - def capture_register(dispatcher_instance): - key = dispatcher_instance.__class__.__name__ - captured_dispatchers[key] = dispatcher_instance - if key == 'TokenDispatcherWithAllGather': - captured_dispatchers[key] = mock_token_dispatcher_with_allgather - elif key == 'TokenDispatcherWithAll2AllV': - captured_dispatchers[key] = mock_token_dispatcher_with_all2allv - elif key == 'TokenDispatcherWithMC2': - captured_dispatchers[key] = mock_token_dispatcher_with_mc2 - - mock_register_token_dispatcher_patcher = patch( - 'vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher', - side_effect=capture_register) - - mock_get_token_dispatcher_patcher = patch( - 'vllm_ascend.ops.moe.token_dispatcher.get_token_dispatcher', - side_effect=lambda name: captured_dispatchers.get(name)) - - default_mock_token_dispatcher = mock_token_dispatcher_with_allgather + mock_moe_comm_method.finalize.side_effect = mock_finalize mock_forward_context_obj = MagicMock( - fused_moe_state=FusedMoEState.AllGather, - token_dispatcher=default_mock_token_dispatcher, + moe_comm_method=mock_moe_comm_method, + moe_comm_method_name="mc2commimpl", max_tokens_across_dp=10, dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]), mc2_mask=torch.zeros(16, dtype=torch.bool), @@ -131,14 +84,12 @@ def mock_dist_env(mocker: MockerFixture): with patch('torch.distributed.get_rank', return_value=0), \ patch('torch.distributed.get_world_size', return_value=4), \ patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \ + patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \ patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \ patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ - patch('torch.distributed.all_gather'), \ - patch('torch.distributed.all_to_all_single'), \ - patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \ patch('vllm.model_executor.layers.fused_moe.config.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm_ascend.ops.fused_moe.get_ascend_config', @@ -150,6 +101,8 @@ def 
mock_dist_env(mocker: MockerFixture): return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \ patch('vllm_ascend.ops.fused_moe.get_forward_context', return_value=mock_forward_context_obj), \ + patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context', + return_value=mock_forward_context_obj), \ patch('vllm_ascend.ops.fused_moe.get_current_vllm_config', return_value=MagicMock( parallel_config=MagicMock(tensor_parallel_size=2), @@ -157,22 +110,20 @@ def mock_dist_env(mocker: MockerFixture): model_config=MagicMock(max_model_len=2048) )), \ patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \ - patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \ patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context', - return_value=mock_forward_context_obj): + return_value=mock_forward_context_obj), \ + patch('vllm_ascend.ops.moe.moe_comm_method.MC2CommImpl._get_token_dispatcher', + return_value=None), \ + patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher', + return_value=None), \ + patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher', + return_value=None): yield { 'mock_forward_context_obj': mock_forward_context_obj, - 'mock_token_dispatcher_with_allgather': - mock_token_dispatcher_with_allgather, - 'mock_token_dispatcher_with_all2allv': - mock_token_dispatcher_with_all2allv, - 'mock_token_dispatcher_with_mc2': mock_token_dispatcher_with_mc2, + 'mock_moe_comm_method': mock_moe_comm_method, } - mock_register_token_dispatcher_patcher.stop() - mock_get_token_dispatcher_patcher.stop() - @pytest.fixture def mock_moe_env(mocker: MockerFixture): @@ -338,9 +289,7 @@ class TestAscendFusedMoe: moe.moe_parallel_config.ep_size = 1 moe.quant_method = MockQuantMethod(shared_experts, num_tokens) - forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens, - dtype=torch.bool), - padded_num_tokens=num_tokens) + forward_context = mock_dist_env['mock_forward_context_obj'] with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context): output = moe.forward(inputs, @@ -394,25 +343,10 @@ class TestAscendUnquantizedFusedMoEMethod: [[256, 4], [128, 1], [128, 1], [128, 4]]) def test_apply_without_expert_map(self, moe_method, mock_dist_env, mock_moe_env, others_param): - global_num_experts, ep_size = others_param is_prefill = False - is_deepseek_v3_r1 = global_num_experts == 256 - if ep_size == 1: - selected_token_dispatcher = mock_dist_env[ - 'mock_token_dispatcher_with_allgather'] - elif ep_size < 16: - selected_token_dispatcher = mock_dist_env[ - 'mock_token_dispatcher_with_all2allv'] - else: - selected_token_dispatcher = mock_dist_env[ - 'mock_token_dispatcher_with_mc2'] - - forward_context = MagicMock(fused_moe_state=_get_fused_moe_state( - ep_size, is_prefill, is_deepseek_v3_r1), - with_quant=False, - token_dispatcher=selected_token_dispatcher) + forward_context = mock_dist_env['mock_forward_context_obj'] with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context): @@ -438,35 +372,22 @@ class TestAscendUnquantizedFusedMoEMethod: global_num_experts=global_num_experts, is_prefill=is_prefill) - expected_shape = (16, 2) + mock_moe_comm_method = mock_dist_env['mock_moe_comm_method'] + mock_moe_comm_method.fused_experts.assert_called_once() + expected_shape = (16, 2) assert result.shape == expected_shape @pytest.mark.parametrize("others_param", [16, 1, 4]) def test_apply_with_expert_map(self, moe_method, 
mock_dist_env, mock_moe_env, others_param): - ep_size = others_param is_prefill = False - if ep_size == 1: - selected_token_dispatcher = mock_dist_env[ - 'mock_token_dispatcher_with_allgather'] - elif ep_size < 16: - selected_token_dispatcher = mock_dist_env[ - 'mock_token_dispatcher_with_all2allv'] - else: - selected_token_dispatcher = mock_dist_env[ - 'mock_token_dispatcher_with_mc2'] - - forward_context = MagicMock(fused_moe_state=_get_fused_moe_state( - ep_size, is_prefill, True), - with_quant=False, - token_dispatcher=selected_token_dispatcher) + forward_context = mock_dist_env['mock_forward_context_obj'] with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \ patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3): - expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]) moe_method.ep_size = ep_size x = torch.randn(8, 2, 2) @@ -493,8 +414,10 @@ class TestAscendUnquantizedFusedMoEMethod: expert_map=expert_map, is_prefill=is_prefill) - expected_shape = (16, 2) + mock_moe_comm_method = mock_dist_env['mock_moe_comm_method'] + mock_moe_comm_method.fused_experts.assert_called_once() + expected_shape = (16, 2) assert result.shape == expected_shape @@ -574,7 +497,7 @@ class TestUnifiedApplyMLP(TestBase): mock_get_forward_context): mock_forward_context = MagicMock() - mock_forward_context.fused_moe_state = FusedMoEState.MC2 + mock_forward_context.moe_comm_method_name = "mc2commimpl" mock_get_forward_context.return_value = mock_forward_context mock_is_310p.return_value = False @@ -618,8 +541,6 @@ class TestUnifiedApplyMLP(TestBase): with_quant=True) mock_get_forward_context.assert_called() - self.assertEqual(mock_forward_context.fused_moe_state, - FusedMoEState.MC2) mock_npu_dynamic_quant.assert_called() diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py index 4273d267..9ba604f6 100644 --- a/tests/ut/ops/test_token_dispatcher.py +++ b/tests/ut/ops/test_token_dispatcher.py @@ -23,8 +23,7 @@ from tests.ut.base import TestBase from vllm_ascend.ops.moe.token_dispatcher import ( # isort: skip AscendSocVersion, TokenDispatcherWithAll2AllV, - TokenDispatcherWithAllGather, TokenDispatcherWithMC2, _Dispatchers, - _register_token_dispatcher, get_token_dispatcher, setup_token_dispatchers) + TokenDispatcherWithAllGather, TokenDispatcherWithMC2) class TestTokenDispatcherWithMC2(TestBase): @@ -521,99 +520,3 @@ class TestTokenDispatcherWithAll2AllV(TestBase): self.assertIsNotNone(result["hidden_states"]) self.assertIsNotNone(result["group_list"]) self.assertEqual(result["group_list_type"], 1) - - -class TestDispatcherRegistry(TestBase): - - def setUp(self): - _Dispatchers.clear() - - def tearDown(self): - _Dispatchers.clear() - - def test_register_and_get_token_dispatcher(self): - mock_dispatcher = MagicMock() - mock_dispatcher.__class__.__name__ = "MockDispatcher" - - _register_token_dispatcher(mock_dispatcher) - - self.assertIn("MockDispatcher", _Dispatchers) - self.assertIs(_Dispatchers["MockDispatcher"], mock_dispatcher) - - retrieved_dispatcher = get_token_dispatcher("MockDispatcher") - self.assertIs(retrieved_dispatcher, mock_dispatcher) - - self.assertIsNone(get_token_dispatcher("NonExistentDispatcher")) - - @patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAllGather') - @patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher') - def test_setup_token_dispatchers_ep_size_1_creates_allgather( - self, mock_register, mock_allgather_class): - kwargs = {"top_k": 2, 
"num_experts": 8} - mock_instance = MagicMock() - mock_allgather_class.return_value = mock_instance - - self.assertNotIn("TokenDispatcherWithAllGather", _Dispatchers) - - setup_token_dispatchers(ep_size=1, **kwargs) - - mock_allgather_class.assert_called_once_with(**kwargs) - mock_register.assert_called_once_with(mock_instance) - - @patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV') - @patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher') - def test_setup_token_dispatchers_ep_size_2_creates_all2allv( - self, mock_register, mock_all2allv_class): - kwargs = {"top_k": 2, "num_experts": 16, "num_local_experts": 2} - mock_instance = MagicMock() - mock_all2allv_class.return_value = mock_instance - - self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers) - - setup_token_dispatchers(ep_size=2, **kwargs) - - mock_all2allv_class.assert_called_once_with(**kwargs) - mock_register.assert_called_once_with(mock_instance) - - @patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV') - @patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithMC2') - @patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher') - def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2( - self, mock_register, mock_mc2_class, mock_all2allv_class): - kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2} - mock_all2allv_instance = MagicMock() - mock_mc2_instance = MagicMock() - mock_all2allv_class.return_value = mock_all2allv_instance - mock_mc2_class.return_value = mock_mc2_instance - - self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers) - self.assertNotIn("TokenDispatcherWithMC2", _Dispatchers) - - setup_token_dispatchers(ep_size=16, **kwargs) - - mock_all2allv_class.assert_called_once_with(**kwargs) - mock_mc2_class.assert_called_once_with(**kwargs) - self.assertEqual(mock_register.call_count, 2) - mock_register.assert_any_call(mock_all2allv_instance) - mock_register.assert_any_call(mock_mc2_instance) - - @patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV') - @patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithMC2') - @patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher') - def test_setup_token_dispatchers_ep_size_16_skips_if_exist( - self, mock_register, mock_mc2_class, mock_all2allv_class): - kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2} - mock_existing_all2allv = MagicMock() - mock_existing_mc2 = MagicMock() - _Dispatchers["TokenDispatcherWithAll2AllV"] = mock_existing_all2allv - _Dispatchers["TokenDispatcherWithMC2"] = mock_existing_mc2 - - setup_token_dispatchers(ep_size=16, **kwargs) - - mock_all2allv_class.assert_not_called() - mock_mc2_class.assert_not_called() - mock_register.assert_not_called() - self.assertIs(_Dispatchers["TokenDispatcherWithAll2AllV"], - mock_existing_all2allv) - self.assertIs(_Dispatchers["TokenDispatcherWithMC2"], - mock_existing_mc2) diff --git a/tests/ut/worker/test_model_runner_v1.py b/tests/ut/worker/test_model_runner_v1.py index 5f7ad90b..9c116def 100644 --- a/tests/ut/worker/test_model_runner_v1.py +++ b/tests/ut/worker/test_model_runner_v1.py @@ -21,37 +21,31 @@ from vllm_ascend.worker.model_runner_v1 import NPUModelRunner # yapf: disable @pytest.mark.parametrize( - "soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, expected_method", + "soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method", [ # Case 1: Expert 
parallel is disabled, should always be 'allgather' - (AscendSocVersion.A2, False, 8, 100, 256, "allgather"), - (AscendSocVersion.A3, False, 16, 500, 256, "allgather"), + (AscendSocVersion.A2, False, 8, 100, 256, None, "allgather"), + (AscendSocVersion.A3, False, 16, 500, 256, None, "allgather"), - # Case 2: A2 SOC - # 2.1: MC2 conditions met (tokens <= capacity, world_size >= 16) - (AscendSocVersion.A2, True, 16, 100, 256, "mc2"), - (AscendSocVersion.A2, True, 32, 256, 256, "mc2"), - # 2.2: MC2 token capacity exceeded - (AscendSocVersion.A2, True, 16, 257, 256, "allgather"), - # 2.3: MC2 world size not met - (AscendSocVersion.A2, True, 8, 100, 256, "allgather"), - (AscendSocVersion.A2, True, 15, 100, 256, "allgather"), + # Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2 + (AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", "alltoall"), + (AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", "alltoall"), + (AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", "mc2"), # meets mc2 condition - # Case 3: A3 SOC - # 3.1: MC2 condition met (tokens <= capacity) - (AscendSocVersion.A3, True, 8, 100, 256, "mc2"), - (AscendSocVersion.A3, True, 16, 256, 256, "mc2"), - # 3.2: MC2 token capacity exceeded - (AscendSocVersion.A3, True, 8, 257, 256, "alltoall"), - (AscendSocVersion.A3, True, 16, 500, 256, "alltoall"), + # Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather + (AscendSocVersion.A2, True, 8, 100, 256, None, "allgather"), + (AscendSocVersion.A2, True, 16, 257, 256, None, "allgather"), + # Case 4: A3 SOC + (AscendSocVersion.A3, True, 8, 100, 256, None, "mc2"), + (AscendSocVersion.A3, True, 8, 257, 256, None, "alltoall"), ]) # yapf: enable def test_select_moe_comm_method(soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, - expected_method): + quant_type, expected_method): """ - Tests the _select_moe_comm_method with various configurations. + Tests the _select_moe_comm_method with various configurations including quant_type. 
""" # Mock the NPUModelRunner instance and its dependencies mock_runner = MagicMock(spec=NPUModelRunner) @@ -60,15 +54,24 @@ def test_select_moe_comm_method(soc_version, enable_expert_parallel, mock_runner.parallel_config.world_size_across_dp = world_size mock_runner.mc2_tokens_capacity = mc2_tokens_capacity + # Add vllm_config.model_config.hf_config mock with moe_quantize + mock_hf_config = MagicMock() + mock_hf_config.moe_quantize = quant_type + mock_model_config = MagicMock() + mock_model_config.hf_config = mock_hf_config + mock_vllm_config = MagicMock() + mock_vllm_config.model_config = mock_model_config + mock_runner.vllm_config = mock_vllm_config + # Patch the helper functions with patch('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version', return_value=soc_version), \ patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank', return_value=True): - # Call the method under test + # Bind the real method to the mock object method = NPUModelRunner._select_moe_comm_method( - mock_runner, num_tokens) + mock_runner, num_tokens, False) # Assert the result assert method == expected_method @@ -83,6 +86,15 @@ def test_select_moe_comm_method_unsupported_soc(): mock_runner.parallel_config.enable_expert_parallel = True mock_runner.mc2_tokens_capacity = 256 + # Add vllm_config.model_config.hf_config mock with moe_quantize + mock_hf_config = MagicMock() + mock_hf_config.moe_quantize = None + mock_model_config = MagicMock() + mock_model_config.hf_config = mock_hf_config + mock_vllm_config = MagicMock() + mock_vllm_config.model_config = mock_model_config + mock_runner.vllm_config = mock_vllm_config + unsupported_soc = "UnsupportedSOC" with patch('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version', @@ -91,4 +103,4 @@ def test_select_moe_comm_method_unsupported_soc(): return_value=True), \ pytest.raises(ValueError, match=f"Unsupported soc_version: {unsupported_soc}"): - NPUModelRunner._select_moe_comm_method(mock_runner, 100) + NPUModelRunner._select_moe_comm_method(mock_runner, 100, False) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index d107a9ed..b368feb7 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -42,17 +42,6 @@ def _get_fused_moe_state(ep_size: int, with_prefill: bool, return FusedMoEState.MC2 -def get_dispatcher_name(ep_size: int, with_prefill: bool) -> str: - if ep_size == 1: - return "TokenDispatcherWithAllGather" - elif envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1: - return "TokenDispatcherWithAllGather" - elif ep_size < 16 or with_prefill: - return "TokenDispatcherWithAll2AllV" - else: - return "TokenDispatcherWithMC2" - - @contextmanager def set_ascend_forward_context( attn_metadata: Any, @@ -97,11 +86,6 @@ def set_ascend_forward_context( forward_context.fused_moe_state = fused_moe_state forward_context.in_profile_run = in_profile_run - from vllm_ascend.ops.moe.token_dispatcher import get_token_dispatcher - dispatcher_name = get_dispatcher_name(ep_size, with_prefill) - dispatcher = get_token_dispatcher(dispatcher_name) - forward_context.token_dispatcher = dispatcher - # NOTE: This cannot be set using set_forward_context # due to multiple warmups before actual capturing forward_context.capturing = False diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index 773684e5..33073f43 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -32,8 +32,8 @@ from vllm_ascend.ascend_config 
import get_ascend_config from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.moe.experts_selector import select_experts from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl, - AlltoAllCommImpl, MC2CommImpl) -from vllm_ascend.ops.moe.token_dispatcher import setup_token_dispatchers + AlltoAllCommImpl, MC2CommImpl, + NaiveMulticastCommImpl) from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__ @@ -280,17 +280,17 @@ class AscendFusedMoE(FusedMoE): num_redundant_experts, has_bias, ) - setup_token_dispatchers(self.moe_config.ep_size, - top_k=self.top_k, - num_experts=self.global_num_experts, - num_local_experts=self.local_num_experts) + self.hidden_size = hidden_size self.moe_config.tp_group = get_tp_group() self.moe_config.dp_group = get_dp_group() self.moe_config.ep_group = get_ep_group() self.moe_config.mc2_group = get_mc2_group() - for method in {AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl}: + for method in { + AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl, + NaiveMulticastCommImpl + }: setattr( self, method.__name__.lower(), method(moe_config=self.moe_config)) # type: ignore[abstract] diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 90582d64..76b677a4 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -19,13 +19,10 @@ import os from typing import Any, Callable, Optional import torch -import torch.distributed as dist import torch_npu -from torch import nn from vllm.config import get_current_vllm_config from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) + get_tensor_model_parallel_world_size) from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, get_tp_group) from vllm.forward_context import get_forward_context @@ -39,72 +36,18 @@ from vllm.model_executor.layers.quantization.base_config import \ QuantizationConfig from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.ops.moe.experts_selector import select_experts -from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp +from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl, + AlltoAllCommImpl, MC2CommImpl, + NaiveMulticastCommImpl) from vllm_ascend.ops.sequence_parallel import MetadataForPadding -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, dispose_tensor, +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, get_all_reduce_merge_state, get_rm_router_logits_state, is_310p) -def unified_fused_experts_eager(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - row_idx: torch.Tensor, - expert_map: Optional[torch.Tensor] = None, - log2phy: Optional[torch.Tensor] = None, - global_redundant_expert_num: int = 0, - w1_scale: Optional[torch.Tensor] = None, - w1_scale_bias: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w2_scale_bias: Optional[torch.Tensor] = None, - shared_experts: Optional[torch.Tensor] = None, - shared_gate_up: Optional[Any] = None, - shared_dequant_scale: Optional[Any] = None, - mc2_mask: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, - with_quant: bool = False, - 
fusion_mlp: bool = False): - token_dispatcher = get_forward_context().token_dispatcher - - results = token_dispatcher.token_dispatch( - hidden_states=hidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - row_idx=row_idx, - expert_map=expert_map, - log2phy=log2phy, - global_redundant_expert_num=global_redundant_expert_num, - shared_experts=shared_experts, - shared_gate_up=shared_gate_up, - shared_dequant_scale=shared_dequant_scale, - mc2_mask=mc2_mask, - apply_router_weight_on_input=apply_router_weight_on_input, - with_quant=with_quant) - - expert_output = unified_apply_mlp( - hidden_states=results["hidden_states"], - w1=w1, - w1_scale=w1_scale, - w2=w2, - w2_scale=w2_scale, - group_list=results["group_list"], - dynamic_scale=results.get("dynamic_scale"), - group_list_type=results.get("group_list_type"), - w1_scale_bias=w1_scale_bias, - w2_scale_bias=w2_scale_bias, - topk_scales=results.get("topk_scales"), - with_quant=with_quant, - fusion=fusion_mlp) - final_hidden_states = token_dispatcher.token_combine(expert_output) - return final_hidden_states - - class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): def __init__(self, moe: FusedMoEConfig = None): @@ -182,17 +125,18 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): if enable_force_load_balance and not self.use_aclgraph: topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) - return unified_fused_experts_eager(hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - row_idx=row_idx, - expert_map=expert_map, - shared_experts=shared_experts, - mc2_mask=kwargs.get( - "mc2_mask", None), - with_quant=False) + moe_comm_method = get_forward_context().moe_comm_method + return moe_comm_method.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + row_idx=row_idx, + global_num_experts=global_num_experts, + expert_map=expert_map, + shared_experts=shared_experts, + need_trans=True) class AscendFusedMoE(FusedMoE): @@ -354,18 +298,20 @@ class AscendFusedMoE(FusedMoE): # NOTE: self.tp_group is not expert_tp_group self.tp_group = get_tp_group().device_group self.quant_method.create_weights(layer=self, **moe_quant_params) - self.token_dispatcher = None - ep_size = (get_ep_group().world_size if - vllm_config.parallel_config.enable_expert_parallel else 1) - from vllm_ascend.ops.moe.token_dispatcher import \ - setup_token_dispatchers - setup_token_dispatchers( - ep_size, - top_k=self.top_k, - num_experts=self.global_num_experts, - num_global_redundant_experts=self.global_redundant_expert_num, - num_local_experts=self.local_num_experts) + self.moe_config.tp_group = get_tp_group() + self.moe_config.dp_group = get_dp_group() + self.moe_config.ep_group = get_ep_group() + self.moe_config.mc2_group = get_mc2_group() + self.moe_config.num_global_redundant_experts = self.global_redundant_expert_num + + for method in { + AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl, + NaiveMulticastCommImpl + }: + setattr( + self, method.__name__.lower(), + method(moe_config=self.moe_config)) # type: ignore[abstract] def naive_multicast(self, x: torch.Tensor, cu_tokens_across_dp_cpu: torch.Tensor): @@ -401,10 +347,7 @@ class AscendFusedMoE(FusedMoE): else: real_top_k = self.top_k - num_tokens, hidden_size = hidden_states.shape - forward_context = get_forward_context() - fused_moe_state = forward_context.fused_moe_state mc2_mask = forward_context.mc2_mask # For w8a8 dynamic we can do 
npu_dynamic_quant and gate in parallel. quantized_x_for_share, dynamic_scale_for_share = None, None @@ -422,63 +365,16 @@ class AscendFusedMoE(FusedMoE): mc2_mask = chunk_mc2_mask[tp_rank] replace_allreduce = True - if (fused_moe_state not in [ - FusedMoEState.AllGather, FusedMoEState.AllGatherEP, - FusedMoEState.NaiveMulticast - ] and not replace_allreduce): - if fused_moe_state in {FusedMoEState.MC2}: - padding_size = forward_context.padded_num_tokens - else: - # TODO: Determine if we can remove the padding - padding_size = tp_size - if num_tokens < padding_size and not self.enable_shared_expert_dp: - hidden_states = nn.functional.pad( - hidden_states, (0, 0, 0, padding_size - num_tokens)) - router_logits = nn.functional.pad( - router_logits, (0, 0, 0, padding_size - num_tokens)) - if tp_size > 1: - tp_rank = get_tensor_model_parallel_rank() - if not self.enable_shared_expert_dp: - chunk_hidden_states = torch.tensor_split(hidden_states, - tp_size, - dim=0) - chunk_router_logits = torch.tensor_split(router_logits, - tp_size, - dim=0) - hidden_states = chunk_hidden_states[tp_rank] - router_logits = chunk_router_logits[tp_rank] + moe_comm_method_name = forward_context.moe_comm_method_name + forward_context.moe_comm_method = getattr(self, moe_comm_method_name) - chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0) - mc2_mask = chunk_mc2_mask[tp_rank] - - if self.dp_size > 1: - if fused_moe_state == FusedMoEState.AllGather: - # NOTE: When in torchair graph, it has been padded in model_runner_v1 - max_tokens_across_dp = forward_context.max_tokens_across_dp - if num_tokens < max_tokens_across_dp: - hidden_states = nn.functional.pad( - hidden_states, - (0, 0, 0, max_tokens_across_dp - num_tokens)) - if not self.rm_router_logits: - router_logits = nn.functional.pad( - router_logits, - (0, 0, 0, max_tokens_across_dp - num_tokens)) - hidden_states = get_dp_group().all_gather(hidden_states, 0) - if self.rm_router_logits: - router_logits, _ = gate(hidden_states) - else: - router_logits = get_dp_group().all_gather(router_logits, 0) - - elif fused_moe_state == FusedMoEState.NaiveMulticast: - cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_dp_cpu - hidden_states = self.naive_multicast(hidden_states, - cu_tokens_across_dp_cpu) - if self.rm_router_logits: - router_logits, _ = gate(hidden_states) - else: - router_logits = self.naive_multicast( - router_logits, cu_tokens_across_dp_cpu) + hidden_states, router_logits = forward_context.moe_comm_method.prepare( + hidden_states=hidden_states, + router_logits=router_logits, + enable_shared_expert_dp=self.enable_shared_expert_dp, + rm_router_logits=self.rm_router_logits, + replace_allreduce=replace_allreduce, + gate=gate) # Matrix multiply. 
e_hidden_states = self.quant_method.apply( @@ -501,7 +397,6 @@ class AscendFusedMoE(FusedMoE): global_redundant_expert_num=self.global_redundant_expert_num, shared_experts=None, mc2_mask=mc2_mask, - token_dispatcher=self.token_dispatcher, quantized_x_for_share=quantized_x_for_share, dynamic_scale_for_share=dynamic_scale_for_share, ) @@ -510,44 +405,9 @@ class AscendFusedMoE(FusedMoE): if isinstance(e_hidden_states, tuple): e_hidden_states, shared_hidden_states = e_hidden_states - if (fused_moe_state not in [ - FusedMoEState.AllGather, FusedMoEState.AllGatherEP, - FusedMoEState.NaiveMulticast - ] and not replace_allreduce and not self.enable_shared_expert_dp): - if tp_size > 1: - dist.all_gather(list(chunk_hidden_states), e_hidden_states, - self.tp_group) - final_hidden_states = torch.cat(chunk_hidden_states, dim=0) - dispose_tensor(e_hidden_states) - else: - final_hidden_states = e_hidden_states - if num_tokens < padding_size: - final_hidden_states = final_hidden_states[:num_tokens] - elif self.dp_size > 1 and not self.enable_shared_expert_dp: - if fused_moe_state == FusedMoEState.NaiveMulticast: - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ - self.dp_rank - 1] - end = cu_tokens_across_dp_cpu[self.dp_rank] - final_hidden_states = get_dp_group().all_reduce( - e_hidden_states) - final_hidden_states = final_hidden_states[start:end, :] - dispose_tensor(e_hidden_states) - elif fused_moe_state == FusedMoEState.AllGather: - final_hidden_states = get_dp_group().reduce_scatter( - e_hidden_states, 0) - final_hidden_states = final_hidden_states[:num_tokens] - dispose_tensor(e_hidden_states) - else: - final_hidden_states = e_hidden_states - else: - final_hidden_states = e_hidden_states - - if tp_size > 1 and not self.all_reduce_merge and fused_moe_state in [ - FusedMoEState.AllGather, FusedMoEState.AllGatherEP, - FusedMoEState.NaiveMulticast - ]: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = forward_context.moe_comm_method.finalize( + hidden_states=e_hidden_states, + reduce_results=(not self.all_reduce_merge)) if shared_experts: return final_hidden_states, shared_hidden_states diff --git a/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py b/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py index b07c4897..bc0d4fb5 100644 --- a/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py +++ b/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py @@ -28,6 +28,16 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig class FusedMoEPrepareAndFinalize(ABC): + """ + Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization + in distributed environments. Subclasses implement specific communication strategies + (e.g., AllGather, All2All, MC2, Naive Multicast) to handle tensor padding, slicing, + broadcasting, and reduction across TP/DP/EP groups. + + Attributes: + moe_config (FusedMoEConfig): Configuration object containing TP/DP/EP group info, + sizes, ranks, and communication settings. + """ def __init__(self, moe_config: FusedMoEConfig): self.moe_config = moe_config @@ -40,22 +50,65 @@ class FusedMoEPrepareAndFinalize(ABC): rm_router_logits: bool = False, replace_allreduce: bool = False, gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Prepare tensors before MoE computation. 
May involve: + - Padding to align communication boundaries + - Slicing across tensor-parallel ranks + - Broadcasting across data-parallel ranks + - Recomputing router logits if needed + + Args: + hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size] + router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts] + enable_shared_expert_dp (bool): Skip DP communication for shared experts + rm_router_logits (bool): Discard input router_logits and recompute via gate + replace_allreduce (bool): Bypass default all-reduce behavior + gate (nn.Module, optional): Gate network to recompute router_logits if needed + + Returns: + Tuple of: + - processed hidden_states (may be padded/sliced/broadcasted) + - processed router_logits (may be recomputed or broadcasted) + - optional communication mask (e.g., mc2_mask for sparse ops) + """ raise NotImplementedError("Prepare not implemented.") def finalize(self, hidden_states: torch.Tensor, reduce_results: bool) -> torch.Tensor: - raise NotImplementedError("Combine function not implemented.") + """ + Finalize MoE output. May involve: + - Gathering sliced tensors across TP ranks + - Reducing or scattering across DP ranks + - Unpadding to original token count + - Applying all-reduce across TP/EP if requested + + Args: + hidden_states (torch.Tensor): MoE layer output, possibly padded or sliced + reduce_results (bool): Whether to apply all-reduce across TP/EP groups + + Returns: + torch.Tensor: Final output with shape [original_num_tokens, hidden_size] + """ + raise NotImplementedError("Finalize function not implemented.") class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize): + """ + MoE communication strategy using MC2 (Merged Compute and Communication). + Designed for Ascend environments that require explicit padding and slicing control. + Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment. + """ def __init__(self, moe_config: FusedMoEConfig): super().__init__(moe_config) self._restore_tp_across_dp() def _restore_tp_across_dp(self): - # NOTE: Since vLLM flatten tp across dp, we need to restore the original - # tp_size and tp_rank. + """ + Restore original TP configuration. + vLLM flattens TP and DP into a single dimension; this method recovers + the true TP world size and rank for correct tensor slicing. + """ self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() @@ -66,9 +119,17 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize): rm_router_logits: bool = False, replace_allreduce: bool = False, gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """The target_pad_length is calculated in forward_context, here we pad the - hidden states and router logits. And if TP size > 1, we also need to split - the tensors accordingly. + """ + Preparation steps: + 1. Fetch `mc2_mask` and target padding length from forward context. + 2. Pad `hidden_states` and `router_logits` to target length if needed. + 3. If TP > 1, split tensors along token dimension and select current TP rank's slice. + 4. Split and return corresponding `mc2_mask`. + + Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True. + + Returns: + Tuple of (hidden_states, router_logits, mc2_mask), possibly sliced/padded.
""" self.replace_allreduce = replace_allreduce self.enable_shared_expert_dp = enable_shared_expert_dp @@ -80,12 +141,14 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize): target_pad_length = forward_context.padded_num_tokens pad_size = target_pad_length - self.num_tokens + # Pad if necessary (unless shared expert DP is enabled) if pad_size > 0 and not self.enable_shared_expert_dp: hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size)) router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size)) + # Slice across TP ranks if self.tp_size > 1: if not self.enable_shared_expert_dp: split_hidden_states = torch.tensor_split(hidden_states, @@ -96,8 +159,9 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize): dim=0) hidden_states = split_hidden_states[self.tp_rank] router_logits = split_router_logits[self.tp_rank] - self.split_hidden_states = split_hidden_states + self.split_hidden_states = split_hidden_states # Save for finalize + # Also slice mc2_mask split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0) @@ -107,16 +171,22 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize): def finalize(self, hidden_states: torch.Tensor, reduce_results: bool) -> torch.Tensor: - """If TP size > 1, all-gather the hidden states to get the final output. - - Also, unpad the hidden states if needed. + """ + Finalization steps: + 1. If TP > 1, all-gather slices from all TP ranks to reconstruct full tensor. + 2. Unpad to original token count if padding was applied. + 3. Return tensor with shape [original_num_tokens, hidden_size]. + + Skips communication and unpadding if `enable_shared_expert_dp` or `replace_allreduce` is True. """ if not (self.enable_shared_expert_dp or self.replace_allreduce): if self.tp_size > 1: + # All-gather across TP group dist.all_gather(list(self.split_hidden_states), hidden_states, self.moe_config.tp_group.device_group) hidden_states = torch.cat(self.split_hidden_states, dim=0) + # Unpad if necessary if self.num_tokens < hidden_states.shape[0]: hidden_states = hidden_states[:self.num_tokens] @@ -124,14 +194,18 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize): class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize): + """ + MoE communication strategy using All-to-All style slicing. + Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing. + Will be used when num_tokens exceed mc2's limitation (512 tokens/rank). + """ def __init__(self, moe_config: FusedMoEConfig): super().__init__(moe_config) self._restore_tp_across_dp() def _restore_tp_across_dp(self): - # NOTE: Since vLLM flatten tp across dp, we need to restore the original - # tp_size and tp_rank. + """Restore original TP configuration (same as MC2).""" self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() @@ -142,12 +216,23 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize): rm_router_logits: bool = False, replace_allreduce: bool = False, gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Preparation steps: + 1. Pad hidden_states and router_logits to next multiple of TP size. + 2. If TP > 1, split along token dim and select current TP rank's slice. + 3. Save splits for later all-gather in finalize. + + Skips if `enable_shared_expert_dp` or `replace_allreduce` is True. + + Returns: + Tuple of (hidden_states, router_logits, None) — no mask used in All2All. 
+ """ self.replace_allreduce = replace_allreduce self.enable_shared_expert_dp = enable_shared_expert_dp if not (self.replace_allreduce or self.enable_shared_expert_dp): self.num_tokens, _ = hidden_states.shape - pad_size = self.tp_size - self.num_tokens + pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic) if pad_size > 0: hidden_states = nn.functional.pad(hidden_states, @@ -171,9 +256,13 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize): def finalize(self, hidden_states: torch.Tensor, reduce_results: bool) -> torch.Tensor: - """If TP size > 1, all-gather the hidden states to get the final output. + """ + Finalization steps: + 1. If TP > 1, all-gather slices to reconstruct full tensor. + 2. Unpad to original token count. + 3. Return [original_num_tokens, hidden_size] tensor. - Also, unpad the hidden states if needed. + Skips if `enable_shared_expert_dp` or `replace_allreduce` is True. """ if not (self.enable_shared_expert_dp or self.replace_allreduce): if self.tp_size > 1: @@ -188,6 +277,11 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize): class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): + """ + MoE communication strategy using All-Gather + Reduce-Scatter. + Designed for DP > 1: gather inputs across DP ranks before MoE, scatter outputs after. + Uses `max_tokens_across_dp` from forward_context for padding alignment. + """ def prepare(self, hidden_states: torch.Tensor, @@ -196,8 +290,16 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): rm_router_logits: bool = False, replace_allreduce: bool = False, gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """When DP size > 1, pad the hidden states and router logits for communication.""" - self.rm_router_logits = rm_router_logits + """ + Preparation steps: + 1. Fetch max token count across DP group from forward context. + 2. Pad local tensors to that size. + 3. All-gather across DP group to form global input tensor. + 4. Optionally recompute router_logits using gate if `rm_router_logits=True`. + + Returns: + Tuple of (global_hidden_states, global_router_logits, None) + """ self.enable_shared_expert_dp = enable_shared_expert_dp if self.moe_config.dp_size > 1: @@ -209,14 +311,15 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): if pad_size > 0: hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size)) - if not self.rm_router_logits: + if not rm_router_logits: router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size)) + # All-gather across DP group hidden_states = self.moe_config.dp_group.all_gather( hidden_states, 0) - if self.rm_router_logits: - router_logits, _ = gate(hidden_states) + if rm_router_logits: + router_logits, _ = gate(hidden_states) # Recompute globally else: router_logits = self.moe_config.dp_group.all_gather( router_logits, 0) @@ -225,9 +328,14 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): def finalize(self, hidden_states: torch.Tensor, reduce_results: bool) -> torch.Tensor: - """When DP size > 1, reduce-scatter the hidden states to get the final output. + """ + Finalization steps: + 1. If DP > 1 and not shared expert, reduce-scatter output across DP group. + 2. Slice to original local token count. + 3. If `reduce_results=True` and TP/EP > 1, apply tensor_model_parallel_all_reduce. - When TP size > 1, all-reduce the hidden states to get the final output. 
+ Returns: + Tensor with shape [original_local_num_tokens, hidden_size] """ if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp: hidden_states = get_dp_group().reduce_scatter(hidden_states, 0) @@ -238,3 +346,101 @@ hidden_states = tensor_model_parallel_all_reduce(hidden_states) return hidden_states + + +class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize): + """ + MoE communication strategy using naive multicast (a loop of per-rank broadcasts). + Used for prefill when AllGather is selected for decode. Each DP rank broadcasts its slice to all the others. + Uses `cu_tokens_across_dp_cpu` (cumulative tokens) to locate slice boundaries. + """ + + def _naive_multicast(self, x: torch.Tensor, + cu_tokens_across_dp_cpu: torch.Tensor): + """ + Naive multicast implementation: + 1. Create global buffer sized by total tokens across DP. + 2. Current rank copies its slice into its designated buffer region. + 3. Each rank broadcasts its slice to all others, one broadcast per rank. + + Args: + x (torch.Tensor): Local tensor [local_tokens, hidden_size] + cu_tokens_across_dp_cpu (torch.Tensor): Cumulative token counts per DP rank + + Returns: + torch.Tensor: Global tensor [total_tokens, hidden_size] + """ + assert len(x.shape) == 2, "Input must be 2D [tokens, features]" + buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), + device=x.device, + dtype=x.dtype) + + # Copy local slice into buffer + start = 0 if self.moe_config.dp_rank == 0 else cu_tokens_across_dp_cpu[ + self.moe_config.dp_rank - 1] + end = cu_tokens_across_dp_cpu[self.moe_config.dp_rank] + buffer[start:end, :].copy_(x) + + # Broadcast each slice to all ranks + for idx in range(self.moe_config.dp_size): + start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] + end = cu_tokens_across_dp_cpu[idx] + get_dp_group().broadcast(buffer[start:end, :], idx) + return buffer + + def prepare(self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + rm_router_logits: bool = False, + replace_allreduce: bool = False, + gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Preparation steps: + 1. Fetch cumulative token boundaries from forward context. + 2. Multicast hidden_states and router_logits to form global tensors. + 3. Optionally recompute router_logits globally if `rm_router_logits=True`. + + Returns: + Tuple of (global_hidden_states, global_router_logits, None) + """ + self.enable_shared_expert_dp = enable_shared_expert_dp + + if self.moe_config.dp_size > 1: + self.cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu + hidden_states = self._naive_multicast(hidden_states, + self.cu_tokens_across_dp_cpu) + if rm_router_logits: + router_logits, _ = gate(hidden_states) + else: + router_logits = self._naive_multicast( + router_logits, self.cu_tokens_across_dp_cpu) + + return hidden_states, router_logits, None + + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + """ + Finalization steps: + 1. If DP > 1 and not shared expert: + - All-reduce across DP + - Slice to current rank's token range using cu_tokens_across_dp_cpu + 2. If `reduce_results=True` and TP/EP > 1, apply tensor_model_parallel_all_reduce.
+ + Returns: + Tensor with shape [local_num_tokens, hidden_size] + """ + if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp: + start = 0 if self.moe_config.dp_rank == 0 else self.cu_tokens_across_dp_cpu[ + self.moe_config.dp_rank - 1] + end = self.cu_tokens_across_dp_cpu[self.moe_config.dp_rank] + hidden_states = get_dp_group().all_reduce( + hidden_states) # Sum across DP + hidden_states = hidden_states[start:end, :] + + if reduce_results and (self.moe_config.tp_size > 1 + or self.moe_config.ep_size > 1): + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + + return hidden_states diff --git a/vllm_ascend/ops/moe/moe_comm_method.py b/vllm_ascend/ops/moe/moe_comm_method.py index af46a3fb..2194f4f7 100644 --- a/vllm_ascend/ops/moe/moe_comm_method.py +++ b/vllm_ascend/ops/moe/moe_comm_method.py @@ -23,7 +23,8 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import ( FusedMoEPrepareAndFinalizeWithAll2All, - FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2) + FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2, + FusedMoEPrepareAndFinalizeWithNaiveMulticast) from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp from vllm_ascend.ops.moe.token_dispatcher import (TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather, @@ -82,8 +83,6 @@ class MoECommMethod(ABC): is_torchair: bool = False, # For Cube/Vector parallel shared_experts: Optional[Any] = None, - shared_gate_up: Optional[Any] = None, - shared_dequant_scale: Optional[Any] = None, quantized_x_for_share: Optional[Any] = None, dynamic_scale_for_share: Optional[Any] = None, # For load balance @@ -91,13 +90,6 @@ class MoECommMethod(ABC): global_redundant_expert_num: int = 0, need_trans: bool = False) -> torch.Tensor: # Check constraints - assert hidden_states.shape[1] == w1.shape[1], ( - f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[1]}") - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert hidden_states.is_contiguous( - ), "Hidden_states must be contiguous" - assert w1.stride(-1) == 1, "Stride of last dimension must be 1" - assert w2.stride(-1) == 1, "Stride of last dimension must be 1" assert hidden_states.dtype in [ torch.float32, torch.float16, torch.bfloat16 ] @@ -114,8 +106,8 @@ class MoECommMethod(ABC): log2phy=log2phy, global_redundant_expert_num=global_redundant_expert_num, shared_experts=shared_experts, - shared_gate_up=shared_gate_up, - shared_dequant_scale=shared_dequant_scale, + quantized_x_for_share=quantized_x_for_share, + dynamic_scale_for_share=dynamic_scale_for_share, mc2_mask=self.mc2_mask, apply_router_weight_on_input=apply_router_weight_on_input, with_quant=use_int8_w8a8 or use_int4_w4a8) @@ -135,12 +127,13 @@ class MoECommMethod(ABC): w2_scale_bias=w2_scale_bias, with_quant=use_int8_w8a8 or use_int4_w4a8, + fusion=use_int8_w8a8, need_trans=need_trans) - hidden_states[:] = self.token_dispatcher.token_combine( + final_hidden_states = self.token_dispatcher.token_combine( hidden_states=mlp_output) - return hidden_states + return final_hidden_states @abstractmethod def _get_token_dispatcher(self): @@ -296,3 +289,32 @@ class AlltoAllCommImpl(MoECommMethod): def _get_fused_moe_prepare_finalize(self): return FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config) + + +class NaiveMulticastCommImpl(MoECommMethod): + """This implementation is the same as NativeAllGatherCommImpl, + but uses NPU-specific ops for better performance. 
diff --git a/vllm_ascend/ops/moe/moe_mlp.py b/vllm_ascend/ops/moe/moe_mlp.py
index 77e83184..b1567b07 100644
--- a/vllm_ascend/ops/moe/moe_mlp.py
+++ b/vllm_ascend/ops/moe/moe_mlp.py
@@ -21,7 +21,6 @@ import torch_npu
 from torch.nn.functional import pad
 from vllm.forward_context import get_forward_context
 
-from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.utils import dispose_tensor, is_310p
 
 
@@ -77,7 +76,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
         bias1, bias2 = None, None
         _output_dtype = w2_scale.dtype
 
-    is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2
+    is_mc2 = get_forward_context().moe_comm_method_name == "mc2commimpl"
     if w1_scale_bias is None and is_mc2:
         if w1_scale.dtype != torch.float32:
             w1_scale = w1_scale.to(torch.float32)
diff --git a/vllm_ascend/ops/moe/token_dispatcher.py b/vllm_ascend/ops/moe/token_dispatcher.py
index f3aba2be..c1a16d02 100644
--- a/vllm_ascend/ops/moe/token_dispatcher.py
+++ b/vllm_ascend/ops/moe/token_dispatcher.py
@@ -22,45 +22,17 @@
 # limitations under the License.
from abc import ABC, abstractmethod -from typing import Any, Dict, Optional +from typing import Any, Optional import torch import torch_npu from vllm.distributed.parallel_state import get_ep_group -import vllm_ascend.envs as envs_ascend from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.moe.comm_utils import ( async_all_to_all, gather_from_sequence_parallel_region) from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version -_Dispatchers: Dict[str, Any] = {} - - -def _register_token_dispatcher(dispatcher: Any): - _Dispatchers[dispatcher.__class__.__name__] = dispatcher - - -def get_token_dispatcher(name: str): - return _Dispatchers.get(name) - - -def setup_token_dispatchers(ep_size: int, **kwargs): - existing_dispatchers = set(_Dispatchers.keys()) - - if ep_size == 1 and "TokenDispatcherWithAllGather" not in existing_dispatchers: - _register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs)) - elif envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1 \ - and "TokenDispatcherWithAllGather" not in existing_dispatchers: - _register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs)) - elif ep_size < 16 and "TokenDispatcherWithAll2AllV" not in existing_dispatchers: - _register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs)) - elif ep_size >= 16: - if "TokenDispatcherWithAll2AllV" not in existing_dispatchers: - _register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs)) - if "TokenDispatcherWithMC2" not in existing_dispatchers: - _register_token_dispatcher(TokenDispatcherWithMC2(**kwargs)) - class MoETokenDispatcher(ABC): @@ -93,9 +65,9 @@ class MoETokenDispatcher(ABC): expert_map: Optional[torch.Tensor] = None, log2phy: Optional[torch.Tensor] = None, global_redundant_expert_num: int = 0, - shared_experts: Optional[torch.Tensor] = None, - shared_gate_up: Optional[torch.Tensor] = None, - shared_dequant_scale: Optional[torch.Tensor] = None, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, mc2_mask: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, with_quant: bool = False): @@ -192,9 +164,9 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): expert_map: Optional[torch.Tensor] = None, log2phy: Optional[torch.Tensor] = None, global_redundant_expert_num: int = 0, - shared_experts: Optional[torch.Tensor] = None, - shared_gate_up: Optional[torch.Tensor] = None, - shared_dequant_scale: Optional[torch.Tensor] = None, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, mc2_mask: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, with_quant: bool = False): @@ -218,6 +190,11 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): if self.with_quant: if shared_experts is not None: + share_up_out, _ = shared_experts.gate_up_proj( + (quantized_x_for_share, dynamic_scale_for_share)) + shared_gate_up, shared_dequant_scale = share_up_out[ + 0], share_up_out[1] + shared_act_out = shared_experts.act_fn( (shared_gate_up, shared_dequant_scale)) self.shared_act, self.swiglu_out_scale = \ @@ -331,9 +308,9 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): expert_map: Optional[torch.Tensor] = None, log2phy: Optional[torch.Tensor] = None, global_redundant_expert_num: int = 0, - shared_experts: Optional[torch.Tensor] = None, - shared_gate_up: Optional[torch.Tensor] = None, - shared_dequant_scale: 
Optional[torch.Tensor] = None, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, mc2_mask: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, with_quant: bool = False): @@ -418,9 +395,9 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher): expert_map: Optional[torch.Tensor] = None, log2phy: Optional[torch.Tensor] = None, global_redundant_expert_num: int = 0, - shared_experts: Optional[torch.Tensor] = None, - shared_gate_up: Optional[torch.Tensor] = None, - shared_dequant_scale: Optional[torch.Tensor] = None, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, mc2_mask: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, with_quant: bool = False): @@ -530,9 +507,9 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): expert_map: Optional[torch.Tensor] = None, log2phy: Optional[torch.Tensor] = None, global_redundant_expert_num: int = 0, - shared_experts: Optional[torch.Tensor] = None, - shared_gate_up: Optional[torch.Tensor] = None, - shared_dequant_scale: Optional[torch.Tensor] = None, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, mc2_mask: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, with_quant: bool = False): diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py index 47aa99cf..0de60b70 100644 --- a/vllm_ascend/quantization/w4a8_dynamic.py +++ b/vllm_ascend/quantization/w4a8_dynamic.py @@ -24,9 +24,7 @@ from vllm.config import get_current_vllm_config from vllm.distributed import get_ep_group from vllm.forward_context import get_forward_context -from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.ops.fused_moe import unified_fused_experts_eager from vllm_ascend.ops.moe.experts_selector import select_experts @@ -275,14 +273,6 @@ class AscendW4A8DynamicFusedMoEMethod: e_score_correction_bias=e_score_correction_bias, global_num_experts=global_num_experts) - fused_moe_state = get_forward_context().fused_moe_state - shared_gate_up, shared_dequant_scale = None, None - if shared_experts is not None and fused_moe_state == FusedMoEState.MC2: - share_up_out, _ = shared_experts.gate_up_proj( - (quantized_x_for_share, dynamic_scale_for_share)) - shared_gate_up, shared_dequant_scale = share_up_out[ - 0], share_up_out[1] - # this is a naive implementation for experts load balance so as # to avoid accumulating too much tokens on a single rank. # currently it is only activated when doing profile runs. 
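The block removed above did not disappear: as the token_dispatcher.py hunk earlier in this change shows, `TokenDispatcherWithMC2.token_dispatch` now performs the shared-experts up-projection itself from `quantized_x_for_share` / `dynamic_scale_for_share`. A condensed sketch of the relocated flow, assuming (as in the real module) that `shared_experts` exposes quantized `gate_up_proj` and `act_fn` operating on (activation, scale) pairs; the standalone function is illustrative only:

```python
def shared_expert_up(shared_experts, quantized_x_for_share,
                     dynamic_scale_for_share):
    # Up-projection on the already-quantized activations.
    share_up_out, _ = shared_experts.gate_up_proj(
        (quantized_x_for_share, dynamic_scale_for_share))
    shared_gate_up, shared_dequant_scale = share_up_out[0], share_up_out[1]
    # The activation returns the activated tensor plus a requantization
    # scale; the dispatcher stashes both for use during token_combine.
    shared_act_out = shared_experts.act_fn(
        (shared_gate_up, shared_dequant_scale))
    return shared_act_out[0], shared_act_out[1]
```

Moving this into the dispatcher lets both quantized methods drop their duplicated `fused_moe_state == FusedMoEState.MC2` special case, as the deletions here and in w8a8_dynamic.py below show.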
@@ -291,7 +281,8 @@ class AscendW4A8DynamicFusedMoEMethod: topk_weights = topk_weights.to(x.dtype) - return unified_fused_experts_eager( + moe_comm_method = get_forward_context().moe_comm_method + return moe_comm_method.fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -302,14 +293,13 @@ class AscendW4A8DynamicFusedMoEMethod: topk_weights=topk_weights, topk_ids=topk_ids, row_idx=row_idx, + use_int4_w4a8=True, expert_map=expert_map, log2phy=log2phy, global_redundant_expert_num=global_redundant_expert_num, shared_experts=shared_experts, - shared_gate_up=shared_gate_up, - shared_dequant_scale=shared_dequant_scale, - mc2_mask=kwargs.get("mc2_mask", None), - with_quant=True) + quantized_x_for_share=quantized_x_for_share, + dynamic_scale_for_share=dynamic_scale_for_share) def process_scale(self, weight: torch.Tensor, scale, per_group_scale): group_num, k, n = weight.shape diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 13839e9c..c34140f3 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -24,9 +24,7 @@ from vllm.distributed import get_ep_group from vllm.forward_context import get_forward_context from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.ops.fused_moe import unified_fused_experts_eager from vllm_ascend.ops.moe.experts_selector import select_experts from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ @@ -233,14 +231,6 @@ class AscendW8A8DynamicFusedMoEMethod: expert_map=expert_map, ) - fused_moe_state = get_forward_context().fused_moe_state - shared_gate_up, shared_dequant_scale = None, None - if shared_experts is not None and fused_moe_state == FusedMoEState.MC2: - share_up_out, _ = shared_experts.gate_up_proj( - (quantized_x_for_share, dynamic_scale_for_share)) - shared_gate_up, shared_dequant_scale = share_up_out[ - 0], share_up_out[1] - # this is a naive implementation for experts load balance so as # to avoid accumulating too much tokens on a single rank. # currently it is only activated when doing profile runs. 
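With the shared-experts precomputation gone, both quantized methods collapse onto one call seam: fetch the communication implementation from the forward context and pass a dtype flag, letting the base `MoECommMethod.fused_experts` derive `with_quant = use_int8_w8a8 or use_int4_w4a8` (see the moe_comm_method.py hunk above). A hedged sketch of the shared pattern; the wrapper name and trimmed argument list are illustrative, not part of this change:

```python
from vllm.forward_context import get_forward_context


def run_quantized_fused_experts(x, layer, topk_weights, topk_ids, row_idx,
                                **flags):
    # flags is use_int4_w4a8=True on the w4a8_dynamic path or
    # use_int8_w8a8=True on the w8a8_dynamic path.
    moe_comm_method = get_forward_context().moe_comm_method
    return moe_comm_method.fused_experts(hidden_states=x,
                                         w1=layer.w13_weight,
                                         w2=layer.w2_weight,
                                         topk_weights=topk_weights,
                                         topk_ids=topk_ids,
                                         row_idx=row_idx,
                                         **flags)
```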
@@ -249,7 +239,8 @@ class AscendW8A8DynamicFusedMoEMethod: topk_weights = topk_weights.to(x.dtype) - return unified_fused_experts_eager( + moe_comm_method = get_forward_context().moe_comm_method + return moe_comm_method.fused_experts( hidden_states=x, w1=layer.w13_weight, w1_scale=layer.w13_weight_scale_fp32, @@ -258,15 +249,13 @@ class AscendW8A8DynamicFusedMoEMethod: topk_weights=topk_weights, topk_ids=topk_ids, row_idx=row_idx, + use_int8_w8a8=True, expert_map=expert_map, log2phy=log2phy, global_redundant_expert_num=global_redundant_expert_num, shared_experts=shared_experts, - shared_gate_up=shared_gate_up, - shared_dequant_scale=shared_dequant_scale, - mc2_mask=kwargs.get("mc2_mask", None), - with_quant=True, - fusion_mlp=True) + quantized_x_for_share=quantized_x_for_share, + dynamic_scale_for_share=dynamic_scale_for_share) def process_weights_after_loading(self, layer): if self.transpose_weight: diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index f627f23c..9fda89a9 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -117,8 +117,11 @@ class EagleProposer(Proposer): skip_attn: bool = False, num_reqs: int = 0, num_tokens_across_dp: Optional[torch.Tensor] = None): + moe_comm_method = self.runner._select_moe_comm_method( + num_tokens, with_prefill) with set_ascend_forward_context(None, self.vllm_config, + moe_comm_method=moe_comm_method, num_tokens=num_tokens): self.model( input_ids=self.input_ids[:num_tokens], @@ -447,12 +450,20 @@ class EagleProposer(Proposer): num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) else: num_input_tokens = num_tokens + + with_prefill = attn_metadata.attn_state not in [ + AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding + ] + moe_comm_method = self.runner._select_moe_comm_method( + num_input_tokens, with_prefill) + # copy inputs to buffer for cudagraph self.positions[:num_tokens] = target_positions.to(device) self.hidden_states[:num_tokens] = target_hidden_states attn_metadata.block_tables = block_table.to(device) with set_ascend_forward_context(attn_metadata, self.vllm_config, + moe_comm_method=moe_comm_method, num_tokens=num_input_tokens): last_hidden_states, hidden_states = self.model( input_ids=self.input_ids[:num_input_tokens], @@ -483,6 +494,10 @@ class EagleProposer(Proposer): input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) else: input_batch_size = batch_size + + moe_comm_method = self.runner._select_moe_comm_method( + input_batch_size, False) + attn_metadata.num_actual_tokens = batch_size attn_metadata.max_query_len = 1 attn_metadata.query_start_loc = self.arange[:batch_size + 1] @@ -553,6 +568,7 @@ class EagleProposer(Proposer): # Run the model. 
        with set_ascend_forward_context(attn_metadata,
                                        self.vllm_config,
+                                       moe_comm_method=moe_comm_method,
                                        num_tokens=input_batch_size):
            last_hidden_states, hidden_states = self.model(
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index 0729b0a2..4749bac8 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -112,6 +112,10 @@ class MtpProposer(Proposer):
         (num_tokens, num_tokens_across_dp, with_prefill,
          _) = self.runner._sync_metadata_across_dp(num_tokens, with_prefill, False)
+
+        moe_comm_method = self.runner._select_moe_comm_method(
+            num_tokens, with_prefill)
+
         is_running_torchair = self.torchair_graph_enabled and \
             not with_prefill
@@ -142,6 +146,7 @@ class MtpProposer(Proposer):
                 with_prefill=with_prefill,
                 num_tokens_across_dp=num_tokens_across_dp,
                 reserved_mc2_mask=self.runner.reserved_mc2_mask,
+                moe_comm_method=moe_comm_method,
                 in_profile_run=self.runner.in_profile_run,
                 num_actual_tokens=0):
             if is_running_torchair:
@@ -411,6 +416,9 @@ class MtpProposer(Proposer):
             num_tokens_across_dp = self.runner.num_tokens_across_dp
             with_prefill = self.runner.with_prefill
 
+            moe_comm_method = self.runner._select_moe_comm_method(
+                num_input_tokens, with_prefill)
+
         for step in range(self.num_speculative_tokens):
             with set_ascend_forward_context(
                     attn_metadata,
@@ -419,6 +427,7 @@ class MtpProposer(Proposer):
                     with_prefill=with_prefill,
                     num_tokens_across_dp=num_tokens_across_dp,
                     reserved_mc2_mask=self.runner.reserved_mc2_mask,
+                    moe_comm_method=moe_comm_method,
                     in_profile_run=self.runner.in_profile_run,
                     num_actual_tokens=num_tokens):
                 with ProfileExecuteDuration().capture_async('mtp_forward'):
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index af122a37..a409bd3e 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1663,7 +1663,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             kv_connector_output=kv_connector_output,
         )
 
-    def _select_moe_comm_method(self, num_tokens: int) -> str:
+    def _select_moe_comm_method(self, num_tokens: int,
+                                with_prefill: bool) -> str:
         """1. If expert parallel is not enabled, we use all-gather since MC2 and
         all-to-all are designed for expert parallelism.
         2. If expert parallel is enabled, we need to consider the soc version and the
@@ -1687,6 +1688,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
-            str: The selected MoE communication method, either "allgather",
-            "mc2", or "alltoall".
+            str: The selected MoE communication method: "allgather", "mc2",
+            "alltoall", or "naivemulticast".
""" soc_version = get_ascend_soc_version() + quant_type = getattr(self.vllm_config.model_config.hf_config, + 'moe_quantize', None) if not self.parallel_config.enable_expert_parallel: moe_comm_method = "allgather" @@ -1694,12 +1697,19 @@ class NPUModelRunner(LoRAModelRunnerMixin): if num_tokens <= self.mc2_tokens_capacity and self.parallel_config.world_size_across_dp >= 16: moe_comm_method = "mc2" else: - moe_comm_method = "allgather" + if quant_type == "w4a8_dynamic": + moe_comm_method = "alltoall" + else: + moe_comm_method = "allgather" + elif soc_version in {AscendSocVersion.A3}: moe_comm_method = "mc2" if num_tokens <= self.mc2_tokens_capacity else "alltoall" else: raise ValueError(f"Unsupported soc_version: {soc_version}") + if moe_comm_method == "allgather" and with_prefill: + moe_comm_method = "naivemulticast" + if is_global_first_rank(): logger.debug(f"num_tokens: {num_tokens}, " f"moe_comm_method: {moe_comm_method}") @@ -1728,7 +1738,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): intermediate_tensors) = (self._prepare_inputs( scheduler_output, intermediate_tensors)) - moe_comm_method = self._select_moe_comm_method(num_input_tokens) + moe_comm_method = self._select_moe_comm_method(num_input_tokens, + self.with_prefill) batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, uniform_decode=False) @@ -2100,7 +2111,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): (num_tokens, num_tokens_across_dp, with_prefill, _) = self._sync_metadata_across_dp(num_tokens, with_prefill, False) - moe_comm_method = self._select_moe_comm_method(num_tokens) + moe_comm_method = self._select_moe_comm_method(num_tokens, + with_prefill) # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.seperate_routine(). This means that we are using