diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index f1be625..6490e9c 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -279,7 +279,6 @@ jobs: pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo - pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_alltoallv pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py index f7354ab..17c3410 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/test_offline_inference_distributed.py @@ -108,14 +108,13 @@ def test_models_distributed_pangu(): ] max_tokens = 5 - with VllmRunner( - snapshot_download("vllm-ascend/pangu-pro-moe-pruing"), - max_model_len=8192, - enforce_eager=True, - dtype="auto", - tensor_parallel_size=2, - distributed_executor_backend="mp", - ) as vllm_model: + with VllmRunner(snapshot_download("vllm-ascend/pangu-pro-moe-pruing"), + max_model_len=8192, + enforce_eager=True, + dtype="auto", + tensor_parallel_size=2, + distributed_executor_backend="mp", + enable_expert_parallel=True) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) @@ -141,28 +140,6 @@ def test_models_distributed_topk() -> None: vllm_model.generate(example_prompts, sampling_params) -@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": "1"}) -def test_models_distributed_alltoallv() -> None: - example_prompts = [ - "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.", - "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.", - "Compare and contrast artificial intelligence with human intelligence in terms of processing information.", - ] - dtype = "half" - sampling_params = SamplingParams(max_tokens=5, - temperature=0.0, - top_k=50, - top_p=0.9) - - with VllmRunner( - "deepseek-ai/DeepSeek-V2-Lite", - dtype=dtype, - tensor_parallel_size=2, - distributed_executor_backend="mp", - ) as vllm_model: - vllm_model.generate(example_prompts, sampling_params) - - def test_models_distributed_Qwen3_W8A8(): example_prompts = [ "Hello, my name is", diff --git a/tests/ut/ops/test_common_fused_moe.py b/tests/ut/ops/test_common_fused_moe.py new file mode 100644 index 0000000..409a301 --- /dev/null +++ b/tests/ut/ops/test_common_fused_moe.py @@ -0,0 +1,69 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +from unittest.mock import patch + +import torch + +from tests.ut.base import TestBase +from vllm_ascend.ops.common_fused_moe import fused_experts_moge + + +class TestFusedExpertsMoGE(TestBase): + + def test_fused_experts_moge(self): + with patch('torch_npu.npu_grouped_matmul') as mock_grouped_matmul, \ + patch('torch_npu.npu_swiglu') as mock_swiglu, \ + patch('vllm_ascend.utils.is_310p') as mock_is_310p: + + mock_is_310p.return_value = False + + mock_grouped_matmul.side_effect = lambda x, weight, **kwargs: [ + torch.randn(x[0].shape[0], weight[0].shape[1]) + ] + + mock_swiglu.side_effect = lambda x: x + + hidden_states = torch.randn(4, 128) + w1 = torch.randn(4, 256, 128) + w2 = torch.randn(4, 128, 128) + topk_weights = torch.rand(4, 1) + topk_ids = torch.tensor([[0], [1], [2], [3]], dtype=torch.long) + top_k = 1 + global_num_experts = 4 + + moe_parallel_config = type( + 'MockConfig', (), { + 'ep_size': 1, + 'tp_size': 1, + 'dp_size': 1, + 'tp_rank': 0, + 'dp_rank': 0, + 'ep_rank': 0, + 'use_ep': True + })() + + output = fused_experts_moge( + hidden_states=hidden_states, + w1=w1, + w2=w2, + moe_parallel_config=moe_parallel_config, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + global_num_experts=global_num_experts, + apply_router_weight_on_input=True, + ) + + self.assertEqual(output.shape, (4, 128)) diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py index 8c4c7f4..6a51d1d 100644 --- a/tests/ut/ops/test_fused_ops.py +++ b/tests/ut/ops/test_fused_ops.py @@ -27,9 +27,9 @@ from tests.ut.base import TestBase from vllm_ascend.ascend_forward_context import (FusedMoEState, _get_fused_moe_state) from vllm_ascend.ops.fused_moe import (AscendFusedMoE, - AscendUnquantizedFusedMoEMethod, - unified_apply_mlp) + AscendUnquantizedFusedMoEMethod) from vllm_ascend.ops.layers.experts_selector import select_experts +from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp from vllm_ascend.utils import AscendSocVersion, adapt_patch adapt_patch(True) @@ -129,36 +129,38 @@ def mock_dist_env(mocker: MockerFixture): with_quant=False) with patch('torch.distributed.get_rank', return_value=0), \ - patch('torch.distributed.get_world_size', return_value=4), \ - patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \ - patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \ - patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ - patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ - patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ - patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ - patch('torch.distributed.all_gather'), \ - patch('torch.distributed.all_to_all_single'), \ - patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \ - patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter'), \ - patch('vllm.model_executor.layers.fused_moe.config.get_dp_group', - return_value=mock_dp_and_tp_group(mocker)), \ - patch('vllm_ascend.ops.fused_moe.get_ascend_config', - return_value=MagicMock( - torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False), - expert_map_path=None - )), \ - patch('vllm_ascend.ops.fused_moe.determine_expert_map', - return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \ - patch('vllm_ascend.ops.fused_moe.get_forward_context', - return_value=mock_forward_context_obj), \ + patch('torch.distributed.get_world_size', return_value=4), \ + patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \ + patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \ + patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ + patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ + patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ + patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ + patch('torch.distributed.all_gather'), \ + patch('torch.distributed.all_to_all_single'), \ + patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \ + patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter'), \ + patch('vllm.model_executor.layers.fused_moe.config.get_dp_group', + return_value=mock_dp_and_tp_group(mocker)), \ + patch('vllm_ascend.ops.fused_moe.get_ascend_config', + return_value=MagicMock( + torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False), + expert_map_path=None + )), \ + patch('vllm_ascend.ops.fused_moe.determine_expert_map', + return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \ + patch('vllm_ascend.ops.fused_moe.get_forward_context', + return_value=mock_forward_context_obj), \ patch('vllm_ascend.ops.fused_moe.get_current_vllm_config', - return_value=MagicMock( - parallel_config=MagicMock(tensor_parallel_size=2), - scheduler_config=MagicMock(max_num_seqs=4), - model_config=MagicMock(max_model_len=2048) - )), \ + return_value=MagicMock( + parallel_config=MagicMock(tensor_parallel_size=2), + scheduler_config=MagicMock(max_num_seqs=4), + model_config=MagicMock(max_model_len=2048) + )), \ patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \ - patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers): + patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \ + patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context', + return_value=mock_forward_context_obj): yield { 'mock_forward_context_obj': mock_forward_context_obj, @@ -441,12 +443,11 @@ class TestAscendUnquantizedFusedMoEMethod: assert result.shape == expected_shape - @pytest.mark.parametrize("others_param", - [[16, False], [1, True], [1, False], [4, False]]) + @pytest.mark.parametrize("others_param", [16, 1, 4]) def test_apply_with_expert_map(self, moe_method, mock_dist_env, mock_moe_env, others_param): - ep_size, alltoall_buffer = others_param + ep_size = others_param is_prefill = False if ep_size == 1: @@ -464,9 +465,7 @@ class TestAscendUnquantizedFusedMoEMethod: with_quant=False, token_dispatcher=selected_token_dispatcher) - with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER", - alltoall_buffer), \ - patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \ + with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \ patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3): expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]) @@ -475,8 +474,6 @@ class TestAscendUnquantizedFusedMoEMethod: if ep_size == 1: x = x.view(-1, 2) router_logits = torch.randn(8, 8) - if alltoall_buffer: - moe_method.max_model_len = 1 layer = MagicMock() local_num_experts = 2 @@ -529,9 +526,8 @@ class TestExpertsSelector: class TestUnifiedApplyMLP(TestBase): - @patch('vllm_ascend.ops.fused_moe.get_forward_context') - @patch('vllm_ascend.ops.fused_moe.get_mc2_group') - @patch('vllm_ascend.ops.fused_moe.is_310p') + @patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context') + @patch('vllm_ascend.ops.layers.moe_mlp.is_310p') @patch('torch_npu.npu_grouped_matmul') @patch('torch_npu.npu_dynamic_quant') @patch('torch_npu.npu_dequant_swiglu_quant') @@ -539,16 +535,12 @@ class TestUnifiedApplyMLP(TestBase): mock_npu_dynamic_quant, mock_npu_grouped_matmul, mock_is_310p, - mock_get_mc2_group, mock_get_forward_context): mock_forward_context = MagicMock() mock_forward_context.fused_moe_state = FusedMoEState.MC2 mock_get_forward_context.return_value = mock_forward_context - mock_mc2_group = MagicMock() - mock_get_mc2_group.return_value = mock_mc2_group - mock_is_310p.return_value = False mock_npu_dynamic_quant.return_value = (torch.randint(-128, @@ -601,7 +593,7 @@ class TestUnifiedApplyMLP(TestBase): self.assertEqual(result.dtype, torch.bfloat16) - @patch('vllm_ascend.ops.fused_moe.is_310p') + @patch('vllm_ascend.ops.layers.moe_mlp.is_310p') @patch('torch_npu.npu_grouped_matmul') @patch('torch_npu.npu_swiglu') @patch('torch_npu.npu_dynamic_quant') @@ -643,7 +635,7 @@ class TestUnifiedApplyMLP(TestBase): self.assertEqual(result.shape, hidden_states.shape) self.assertEqual(result.dtype, torch.float16) - @patch('vllm_ascend.ops.fused_moe.get_forward_context') + @patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context') @patch('torch_npu.npu_grouped_matmul') @patch('torch_npu.npu_swiglu') @patch('torch_npu.npu_dynamic_quant') @@ -703,7 +695,7 @@ class TestUnifiedApplyMLP(TestBase): self.assertEqual(result.shape, hidden_states.shape) self.assertEqual(result.dtype, torch.bfloat16) - @patch('vllm_ascend.ops.fused_moe.is_310p') + @patch('vllm_ascend.ops.layers.moe_mlp.is_310p') @patch('torch_npu.npu_grouped_matmul') @patch('torch_npu.npu_swiglu') @patch('torch_npu.npu_dynamic_quant') diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py index be0a4f9..9de8a13 100644 --- a/tests/ut/ops/test_token_dispatcher.py +++ b/tests/ut/ops/test_token_dispatcher.py @@ -17,57 +17,13 @@ from unittest.mock import MagicMock, PropertyMock, patch -import pytest import torch -from pytest_mock import MockerFixture -from tests.ut.base import PytestBase, TestBase +from tests.ut.base import TestBase from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( - AscendSocVersion, MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig, - TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather, - TokenDispatcherWithMC2, _Dispatchers, _register_token_dispatcher, - get_token_dispatcher, setup_token_dispatchers) - - -class TestMoEAlltoAllSeqOverLapDispatcher(PytestBase): - - @pytest.fixture - def config(self): - config = MoEDispatcherConfig() - config.set_num_local_experts(2) - config.set_num_moe_experts(4) - config.set_moe_pad_expert_input_to_capacity(False) - config.set_moe_expert_capacity_factor(None) - config.set_moe_router_topk(2) - config.set_moe_grouped_gemm(False) - config.set_group_topk(0) - config.set_num_groups(1) - config.set_is_fused(False) - return config.build() - - def mock_ep_group(self, mocker): - mock_group = mocker.MagicMock() - mock_group.rank_in_group = 0 - mock_group.world_size = 2 - mock_group.device_group = "mock_group" - return mock_group - - @pytest.fixture - def dispatcher(self, config, mocker: MockerFixture): - mocker.patch( - "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ep_group", - return_value=self.mock_ep_group(mocker)) - mocker.patch("torch.npu.current_device", return_value="cpu") - mocker.patch("torch.npu.Stream", return_value=mocker.MagicMock) - return MoEAlltoAllSeqOverLapDispatcher(config) - - def test_initialization(self, dispatcher, config): - assert dispatcher.num_local_experts == config.num_local_experts - assert dispatcher.num_experts == config.num_moe_experts - assert dispatcher.local_expert_indices == [0, 1] - assert dispatcher.ep_rank == 0 - assert dispatcher.ep_size == 2 - assert dispatcher.overlap_stream is not None + AscendSocVersion, TokenDispatcherWithAll2AllV, + TokenDispatcherWithAllGather, TokenDispatcherWithMC2, _Dispatchers, + _register_token_dispatcher, get_token_dispatcher, setup_token_dispatchers) class TestTokenDispatcherWithMC2(TestBase): diff --git a/tests/ut/torchair/ops/test_torchair_fused_moe.py b/tests/ut/torchair/ops/test_torchair_fused_moe.py index 149f6d4..7a6fc10 100644 --- a/tests/ut/torchair/ops/test_torchair_fused_moe.py +++ b/tests/ut/torchair/ops/test_torchair_fused_moe.py @@ -353,8 +353,7 @@ class TestTorchairAscendUnquantizedFusedMoEMethod: else: assert result.shape == x.shape - @pytest.mark.parametrize("others_param", - [[16, False], [1, True], [1, False], [4, False]]) + @pytest.mark.parametrize("others_param", [16, 1, 4]) def test_apply_with_expert_map(self, moe_method, mock_dist_env, mock_moe_env, others_param): """ @@ -363,13 +362,11 @@ class TestTorchairAscendUnquantizedFusedMoEMethod: 3 test use_select_experts and fused_experts_with_all2all 4 test use_select_experts and fused_experts """ - ep_size, alltoall_buffer = others_param + ep_size = others_param is_prefill = False forward_context = MagicMock( fused_moe_state=_get_fused_moe_state(ep_size, is_prefill, True)) - with patch("vllm_ascend.torchair.ops.torchair_fused_moe.MOE_ALL2ALL_BUFFER", - alltoall_buffer), \ - patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context", return_value=forward_context), \ + with patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context", return_value=forward_context), \ patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3): expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]) moe_method.ep_size = ep_size @@ -377,8 +374,6 @@ class TestTorchairAscendUnquantizedFusedMoEMethod: if ep_size == 1: x = x.view(-1, 2) router_logits = torch.randn(8, 8) - if alltoall_buffer: - moe_method.max_model_len = 1 layer = MagicMock() layer.w13_weight = torch.randn(8, 16, 1) layer.w2_weight = torch.randn(16, 8, 1) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 7ddbc82..601f33a 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -35,10 +35,6 @@ def _get_fused_moe_state(ep_size: int, with_prefill: bool, return FusedMoEState.NaiveMulticast else: return FusedMoEState.AllGather - elif envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ: - # MC2 Dispatch/Combine performs better than alltoall_seq in decoding stage. - return (FusedMoEState.All2AllSeq if - (ep_size < 16 or with_prefill) else FusedMoEState.MC2) # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph. elif ep_size < 16 or with_prefill: return FusedMoEState.All2All diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 8469297..625b65a 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -90,11 +90,6 @@ env_variables: Dict[str, Callable[[], Any]] = { "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE": lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0')) ), - # MOE_ALL2ALL_BUFFER: - # 0: default, normal init. - # 1: enable moe_all2all_buffer. - "MOE_ALL2ALL_BUFFER": - lambda: bool(int(os.getenv("MOE_ALL2ALL_BUFFER", '0'))), # Some models are optimized by vllm ascend. While in some case, e.g. rlhf # training, the optimized model may not be suitable. In this case, set this # value to False to disable the optimized model. @@ -136,11 +131,6 @@ env_variables: Dict[str, Callable[[], Any]] = { # this feature is supported in A2, and eager mode will get better performance. "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))), - # Whether to enable the alltoall_seq flag, this provides a basic framework on the basis of alltoall for easy expansion. - # 0: default, normal init. - # 1: enable moe all2all seq. - "VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": - lambda: bool(int(os.getenv('VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ', '0'))), # Whether to enable mlp optimize when tensor parallel is enabled. # this feature in eager mode will get better performance. "VLLM_ASCEND_ENABLE_MLP_OPTIMIZE": diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index a44ab68..7265113 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -22,6 +22,8 @@ import torch_npu from vllm.config import CompilationLevel, get_current_vllm_config from vllm.distributed import get_dp_group, get_ep_group, get_tp_group from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fused_moe.config import \ + FusedMoEParallelConfig # isort: skip from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, UnquantizedFusedMoEMethod) @@ -30,7 +32,6 @@ from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl) from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.ops.fused_moe import fused_experts_moge from vllm_ascend.ops.layers.experts_selector import select_experts from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \ setup_token_dispatchers @@ -139,6 +140,95 @@ def fused_experts( return hidden_states +def fused_experts_moge( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + moe_parallel_config: FusedMoEParallelConfig, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + global_num_experts: int, + expert_map: torch.Tensor = None, + apply_router_weight_on_input: bool = False, +) -> torch.Tensor: + """ + + Args: + hidden_states: Hidden states of shape (num_tokens, hidden_size). + w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size). + w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size). + topk_weights: Routing weights of shape (num_tokens, top_k). + topk_ids: Selected expert IDs of shape (num_tokens, top_k). + top_k: Number of experts to select. + expert_map: Expert mapping of shape (num_experts,). + + Returns: + hidden_states: Hidden states after routing. + """ + ep_size = moe_parallel_config.ep_size + local_num_experts = global_num_experts // ep_size + local_num_group = top_k // ep_size + + if apply_router_weight_on_input: + assert (topk_weights.dim() == 2 + ), "`topk_weights` should be in shape (num_tokens, topk)" + _, topk = topk_weights.shape + assert ( + topk == 1 + ), "Only support topk=1 when `apply_router_weight_on_input` is True" + hidden_states = hidden_states * topk_weights.to(hidden_states.dtype) + + bsz, _ = hidden_states.shape + flatten_topk_ids = topk_ids.view(-1) + sorted_topk_ids = torch.argsort(flatten_topk_ids.float()) + sorted_topk_ids = sorted_topk_ids.to(torch.int32) + sorted_hidden_states = hidden_states.index_select( + 0, sorted_topk_ids // local_num_group) + + experts_id = torch.arange(0, + local_num_experts, + dtype=topk_ids.dtype, + device=topk_ids.device) + num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to( + torch.float32).sum(0) + topk_scales = topk_weights.view(-1).index_select( + 0, sorted_topk_ids).unsqueeze(-1) + group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64) + + gate_up_out = torch_npu.npu_grouped_matmul( + x=[sorted_hidden_states], + weight=[w1], + split_item=2, + group_list_type=0, + group_type=0, + group_list=group_list, + )[0] + + if is_310p(): + gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to( + torch.float16) + else: + gate_up_out = torch_npu.npu_swiglu(gate_up_out) + gate_up_out *= topk_scales + + down_out_list = torch_npu.npu_grouped_matmul( + x=[gate_up_out], + weight=[w2], + split_item=2, + group_list_type=0, + group_type=0, + group_list=group_list, + )[0] + + unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32) + unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids) + final_hidden_states = unsorted_hidden_states.reshape( + bsz, top_k // ep_size, -1).sum(1) + + return final_hidden_states + + def unquantized_fused_moe_init_func(self, *args, **kwargs): original_unquantized_fused_moe_init_func(self, *args, **kwargs) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index a84c104..e86f77d 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -38,7 +38,6 @@ from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.quantization.base_config import \ QuantizationConfig -import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.communication_op import \ @@ -46,397 +45,12 @@ from vllm_ascend.distributed.communication_op import \ from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.ops.layers.experts_selector import select_experts -from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( - MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig) +from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp from vllm_ascend.ops.sequence_parallel import MetadataForPadding from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, dispose_tensor, get_all_reduce_merge_state, get_rm_router_logits_state, is_310p) -MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER - - -def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, - max_row_per_ep_rank: int, num_tokens: int, - top_k: int) -> tuple[torch.Tensor, torch.Tensor]: - original_total_elements = num_tokens * top_k - device = topk_ids.device - original_dtype = topk_ids.dtype - - if original_total_elements == 0: - output_len = ep_size * max_row_per_ep_rank - topk_ids_pad = torch.full((output_len, ), - expert_num, - dtype=original_dtype, - device=device) - unpad_indices = torch.full((original_total_elements, ), - -1, - dtype=torch.long, - device=device) - return topk_ids_pad, unpad_indices - - experts_per_ep_rank_val = expert_num // ep_size - if experts_per_ep_rank_val == 0: - raise ValueError( - "expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. " - "Ensure expert_num >= ep_size.") - - assigned_ep_rank = (topk_ids.float() / - experts_per_ep_rank_val).to(original_dtype) - indices_arange = torch.arange(topk_ids.shape[0], device=device) - - is_new_segment = torch.cat( - (torch.tensor([True], device=device), assigned_ep_rank[1:] - != assigned_ep_rank[:-1])) - temp_start_markers = torch.full_like(indices_arange, - -1, - dtype=indices_arange.dtype) - temp_start_markers[is_new_segment] = indices_arange[is_new_segment] - start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0] - token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token - is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank - cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long) - indices_in_rec_cond_list_for_all = cumsum_kept - 1 - unpad_indices = torch.where( - is_kept_mask, indices_in_rec_cond_list_for_all, - torch.tensor(-1, device=device, dtype=torch.long)) - output_len = ep_size * max_row_per_ep_rank - topk_ids_pad = torch.full((output_len, ), - expert_num, - dtype=original_dtype, - device=device) - if topk_ids.shape[0] > 0: - all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx - temp_pad_buffer = torch.full((output_len + 1, ), - expert_num, - dtype=original_dtype, - device=device) - output_len_tensor = torch.tensor(output_len, - dtype=torch.long, - device=device) - scatter_indices = torch.where(is_kept_mask, all_destination_indices, - output_len_tensor) - temp_pad_buffer.scatter_(0, scatter_indices, topk_ids) - topk_ids_pad = temp_pad_buffer[:output_len] - return topk_ids_pad, unpad_indices - - -def apply_mlp( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - group_list: torch.Tensor, - group_list_type: int = 1, -) -> torch.Tensor: - """ - apply MLP: gate_up_proj -> swiglu -> down_proj - - Args: - hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size). - w1: expert weights1 with shape - (num_experts, hidden_size, intermediate_size * 2) - w2: expert weights2 with shape - (num_experts, intermediate_size, hidden_size) - group_list: number of tokens for each expert, follow cumsum mode, and - with shape (num_experts). - transpose_weight: - w1: (num_experts, intermediate_size * 2, hidden_size) -> - (num_experts, hidden_size, intermediate_size * 2) - w2: (num_experts, hidden_size, intermediate_size) -> - (num_experts, intermediate_size, hidden_size) - - Returns: - hidden_states: output hidden states after MLP. - """ - - w1 = w1.transpose(1, 2) - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w1], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - )[0] - - hidden_states = torch_npu.npu_swiglu(hidden_states) - - w2 = w2.transpose(1, 2) - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w2], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - )[0] - - return hidden_states - - -def fused_experts_moge( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - moe_parallel_config: FusedMoEParallelConfig, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - global_num_experts: int, - expert_map: torch.Tensor = None, - apply_router_weight_on_input: bool = False, -) -> torch.Tensor: - """ - - Args: - hidden_states: Hidden states of shape (num_tokens, hidden_size). - w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size). - w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size). - topk_weights: Routing weights of shape (num_tokens, top_k). - topk_ids: Selected expert IDs of shape (num_tokens, top_k). - top_k: Number of experts to select. - expert_map: Expert mapping of shape (num_experts,). - - Returns: - hidden_states: Hidden states after routing. - """ - ep_size = moe_parallel_config.ep_size - local_num_experts = global_num_experts // ep_size - local_num_group = top_k // ep_size - - if apply_router_weight_on_input: - assert (topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" - _, topk = topk_weights.shape - assert ( - topk == 1 - ), "Only support topk=1 when `apply_router_weight_on_input` is True" - hidden_states = hidden_states * topk_weights.to(hidden_states.dtype) - - bsz, _ = hidden_states.shape - flatten_topk_ids = topk_ids.view(-1) - sorted_topk_ids = torch.argsort(flatten_topk_ids.float()) - sorted_topk_ids = sorted_topk_ids.to(torch.int32) - sorted_hidden_states = hidden_states.index_select( - 0, sorted_topk_ids // local_num_group) - - experts_id = torch.arange(0, - local_num_experts, - dtype=topk_ids.dtype, - device=topk_ids.device) - num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to( - torch.float32).sum(0) - topk_scales = topk_weights.view(-1).index_select( - 0, sorted_topk_ids).unsqueeze(-1) - group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64) - - gate_up_out = torch_npu.npu_grouped_matmul( - x=[sorted_hidden_states], - weight=[w1], - split_item=2, - group_list_type=0, - group_type=0, - group_list=group_list, - )[0] - - if is_310p(): - gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to( - torch.float16) - else: - gate_up_out = torch_npu.npu_swiglu(gate_up_out) - gate_up_out *= topk_scales - - down_out_list = torch_npu.npu_grouped_matmul( - x=[gate_up_out], - weight=[w2], - split_item=2, - group_list_type=0, - group_type=0, - group_list=group_list, - )[0] - - unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32) - unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids) - final_hidden_states = unsorted_hidden_states.reshape( - bsz, top_k // ep_size, -1).sum(1) - - return final_hidden_states - - -def quant_apply_mlp(hidden_states: torch.Tensor, - w1: torch.Tensor, - w1_scale: torch.Tensor, - w2: torch.Tensor, - w2_scale: torch.Tensor, - group_list: torch.Tensor, - dynamic_scale: torch.Tensor = None, - group_list_type: int = 1, - w1_scale_bias: torch.Tensor = None, - w2_scale_bias: torch.Tensor = None) -> torch.Tensor: - if dynamic_scale is None: - unquantized_hidden_states = hidden_states - hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( - hidden_states) - # Dispose the original unquantized hidden states - # to save npu memory because they're no longer used. - dispose_tensor(unquantized_hidden_states) - else: - pertoken_scale = dynamic_scale - - bias1, bias2 = None, None - _output_dtype = w2_scale.dtype - - is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2 - if w1_scale_bias is None and is_mc2: - w1_scale = w1_scale.to(torch.float32) - - # gmm1: gate_up_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w1], - split_item=3, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=torch.int32)[0] - - # act_fn: swiglu - hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( - x=hidden_states, - weight_scale=w1_scale, - activation_scale=pertoken_scale, - bias=None, - quant_scale=None, - quant_offset=None, - group_index=group_list, - activate_left=True, - quant_mode=1, - ) - - # gmm2: down_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w2], - scale=[w2_scale], - per_token_scale=[swiglu_out_scale], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=w2_scale.dtype)[0] - else: - if w1_scale_bias is not None: - if group_list_type == 0: - group_list = torch.cat( - [group_list[:1], - torch.diff(group_list, dim=0)]) - group_list_type = 1 - bias1 = [w1_scale_bias] - bias2 = [w2_scale_bias] - # TODO w4a8 scene: dynamic acquisition of dtype in the future - _output_dtype = torch.bfloat16 - - # gmm1: gate_up_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w1], - scale=[w1_scale], - bias=bias1, - per_token_scale=[pertoken_scale], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=_output_dtype)[0] - - # act_fn: swiglu - hidden_states = torch_npu.npu_swiglu(hidden_states) - hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant( - hidden_states) - - # gmm2: down_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w2], - scale=[w2_scale], - bias=bias2, - per_token_scale=[swiglu_out_scale], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=_output_dtype)[0] - return hidden_states - - -def unquant_apply_mlp( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - group_list: torch.Tensor, - group_list_type: int = 1, - topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor: - w1 = w1.transpose(1, 2) - gate_up_out = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w1], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - )[0] - if is_310p(): - gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to( - torch.float16) - else: - gate_up_out = torch_npu.npu_swiglu(gate_up_out) - - if topk_scales is not None: - gate_up_out *= topk_scales - - w2 = w2.transpose(1, 2) - hidden_states = torch_npu.npu_grouped_matmul( - x=[gate_up_out], - weight=[w2], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - )[0] - return hidden_states - - -def unified_apply_mlp(hidden_states: torch.Tensor, - w1: torch.Tensor, - w1_scale: torch.Tensor, - w2: torch.Tensor, - w2_scale: torch.Tensor, - group_list: torch.Tensor, - dynamic_scale: torch.Tensor = None, - group_list_type: int = 1, - w1_scale_bias: torch.Tensor = None, - w2_scale_bias: torch.Tensor = None, - topk_scales: Optional[torch.Tensor] = None, - with_quant: bool = False) -> torch.Tensor: - if with_quant: - return quant_apply_mlp(hidden_states=hidden_states, - w1=w1, - w1_scale=w1_scale, - w2=w2, - w2_scale=w2_scale, - group_list=group_list, - dynamic_scale=dynamic_scale, - group_list_type=group_list_type, - w1_scale_bias=w1_scale_bias, - w2_scale_bias=w2_scale_bias) - else: - return unquant_apply_mlp(hidden_states=hidden_states, - w1=w1, - w2=w2, - group_list=group_list, - group_list_type=group_list_type, - topk_scales=topk_scales) - def unified_fused_experts_eager(hidden_states: torch.Tensor, w1: torch.Tensor, @@ -742,24 +356,6 @@ class AscendFusedMoE(FusedMoE): self.tp_group = get_tp_group().device_group self.quant_method.create_weights(layer=self, **moe_quant_params) self.token_dispatcher = None - if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ and isinstance( - self.quant_method, AscendUnquantizedFusedMoEMethod): - self.reduce_results = False - moe_dispatcher_config = ( - MoEDispatcherConfig().set_num_moe_experts( - self.global_num_experts).set_num_local_experts( - self.local_num_experts).set_moe_router_topk( - top_k).set_group_topk(topk_group). - set_num_groups(num_expert_group).set_expert_bias( - e_score_correction_bias).set_scaling_factor(1.0).build()) - self.token_dispatcher = MoEAlltoAllSeqOverLapDispatcher( - moe_dispatcher_config) - if envs_ascend.VLLM_ASCEND_ENABLE_DBO: - token_dispatcher1 = MoEAlltoAllSeqOverLapDispatcher( - moe_dispatcher_config) - self.token_dispatchers = [ - self.token_dispatcher, token_dispatcher1 - ] ep_size = (get_ep_group().world_size if vllm_config.parallel_config.enable_expert_parallel else 1) diff --git a/vllm_ascend/ops/layers/moe_mlp.py b/vllm_ascend/ops/layers/moe_mlp.py new file mode 100644 index 0000000..c73e8ea --- /dev/null +++ b/vllm_ascend/ops/layers/moe_mlp.py @@ -0,0 +1,199 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. + +from typing import Optional + +import torch +import torch_npu +from vllm.forward_context import get_forward_context + +from vllm_ascend.ascend_forward_context import FusedMoEState +from vllm_ascend.utils import dispose_tensor, is_310p + + +def quant_apply_mlp(hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + group_list: torch.Tensor, + dynamic_scale: torch.Tensor = None, + group_list_type: int = 1, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None) -> torch.Tensor: + if dynamic_scale is None: + unquantized_hidden_states = hidden_states + hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( + hidden_states) + # Dispose the original unquantized hidden states + # to save npu memory because they're no longer used. + dispose_tensor(unquantized_hidden_states) + else: + pertoken_scale = dynamic_scale + + bias1, bias2 = None, None + _output_dtype = w2_scale.dtype + + is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2 + if w1_scale_bias is None and is_mc2: + w1_scale = w1_scale.to(torch.float32) + + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + split_item=3, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=torch.int32)[0] + + # act_fn: swiglu + hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( + x=hidden_states, + weight_scale=w1_scale, + activation_scale=pertoken_scale, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=group_list, + activate_left=True, + quant_mode=1, + ) + + # gmm2: down_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + scale=[w2_scale], + per_token_scale=[swiglu_out_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=w2_scale.dtype)[0] + else: + if w1_scale_bias is not None: + if group_list_type == 0: + group_list = torch.cat( + [group_list[:1], + torch.diff(group_list, dim=0)]) + group_list_type = 1 + bias1 = [w1_scale_bias] + bias2 = [w2_scale_bias] + # TODO w4a8 scene: dynamic acquisition of dtype in the future + _output_dtype = torch.bfloat16 + + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + scale=[w1_scale], + bias=bias1, + per_token_scale=[pertoken_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=_output_dtype)[0] + + # act_fn: swiglu + hidden_states = torch_npu.npu_swiglu(hidden_states) + hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant( + hidden_states) + + # gmm2: down_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + scale=[w2_scale], + bias=bias2, + per_token_scale=[swiglu_out_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=_output_dtype)[0] + return hidden_states + + +def unquant_apply_mlp( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + group_list: torch.Tensor, + group_list_type: int = 1, + topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor: + w1 = w1.transpose(1, 2) + gate_up_out = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + )[0] + if is_310p(): + gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to( + torch.float16) + else: + gate_up_out = torch_npu.npu_swiglu(gate_up_out) + + if topk_scales is not None: + gate_up_out *= topk_scales + + w2 = w2.transpose(1, 2) + hidden_states = torch_npu.npu_grouped_matmul( + x=[gate_up_out], + weight=[w2], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + )[0] + return hidden_states + + +def unified_apply_mlp(hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + group_list: torch.Tensor, + dynamic_scale: torch.Tensor = None, + group_list_type: int = 1, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None, + topk_scales: Optional[torch.Tensor] = None, + with_quant: bool = False) -> torch.Tensor: + if with_quant: + return quant_apply_mlp(hidden_states=hidden_states, + w1=w1, + w1_scale=w1_scale, + w2=w2, + w2_scale=w2_scale, + group_list=group_list, + dynamic_scale=dynamic_scale, + group_list_type=group_list_type, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias) + else: + return unquant_apply_mlp(hidden_states=hidden_states, + w1=w1, + w2=w2, + group_list=group_list, + group_list_type=group_list_type, + topk_scales=topk_scales) diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py index a5ca03a..855faad 100644 --- a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -29,434 +29,11 @@ import torch_npu from vllm.distributed.parallel_state import get_ep_group from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.distributed.tensor_parallel import ( - all_gather_last_dim_from_tensor_parallel_region, all_to_all_hp2sp, - all_to_all_sp2hp, gather_from_sequence_parallel_region, - reduce_scatter_last_dim_to_tensor_parallel_region) +from vllm_ascend.distributed.tensor_parallel import \ + gather_from_sequence_parallel_region from vllm_ascend.ops.comm_utils import async_all_to_all from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version - -class MoEDispatcherConfig: - - def __init__(self): - self.num_local_experts: int = 0 - self.num_moe_experts: int = 0 - self.moe_pad_expert_input_to_capacity: bool = False - self.moe_expert_capacity_factor: Optional[float] = None - self.moe_router_topk: int = 2 - self.moe_grouped_gemm: bool = False - self.group_topk: int = 0 - self.num_groups: int = 1 - self.expert_bias: torch.Tensor = None - self.scaling_factor: Optional[float] = None - self.is_fused: bool = True - - def set_num_local_experts(self, num_local_experts): - self.num_local_experts = num_local_experts - return self - - def set_num_moe_experts(self, num_moe_experts): - self.num_moe_experts = num_moe_experts - return self - - def set_moe_pad_expert_input_to_capacity(self, - moe_pad_expert_input_to_capacity): - self.moe_pad_expert_input_to_capacity = moe_pad_expert_input_to_capacity - return self - - def set_moe_expert_capacity_factor(self, moe_expert_capacity_factor): - self.moe_expert_capacity_factor = moe_expert_capacity_factor - return self - - def set_moe_router_topk(self, moe_router_topk): - self.moe_router_topk = moe_router_topk - return self - - def set_moe_grouped_gemm(self, moe_grouped_gemm): - self.moe_grouped_gemm = moe_grouped_gemm - return self - - def set_group_topk(self, group_topk): - self.group_topk = group_topk - return self - - def set_num_groups(self, num_groups): - self.num_groups = num_groups - return self - - def set_expert_bias(self, expert_bias): - self.expert_bias = expert_bias - return self - - def set_scaling_factor(self, scaling_factor): - self.scaling_factor = scaling_factor - return self - - def set_is_fused(self, is_fused): - self.is_fused = is_fused - return self - - def build(self): - return self - - -class MoEDispatcher: - - def __init__(self, config: MoEDispatcherConfig) -> None: - """ - Initialize the MoE Token Dispatcher. - """ - self.config = config - self.shared_experts = None - - def set_shared_experts(self, shared_experts): - self.shared_experts = shared_experts - - @property - def ep_group(self): - """Get expert model parallel group.""" - return get_ep_group().device_group - - @property - def ep_rank(self): - return get_ep_group().rank_in_group - - @property - def ep_size(self): - return get_ep_group().world_size - - @property - def tp_ep_group(self): - """Get expert tensor and model parallel group.""" - return None - - @property - def tp_ep_size(self): - return 1 - - -class MoEAlltoAllSeqOverLapDispatcher(MoEDispatcher): - overlap_stream = None - """ - The implementation of the AlltoAll-based token dispatcher, which handles token - dispatching on the sequence level instead of token level. The core of this implementation - lies in each device dispatching on the entire sequence, with the hidden state being partitioned. - - """ - - def __init__(self, config: MoEDispatcherConfig): - """ - Initialize the AlltoAllSeq token dispatcher. - - Args: - config (MoEDispatcherConfig): Configuration for the transformer model. - """ - super().__init__(config) - self.num_local_experts = config.num_local_experts - self.config = config - # use MOEAlltoAllSEQTokenDispatcher to init - - self.hidden_shape = None - self.num_input_tokens = None - self.num_experts = config.num_moe_experts - assert self.num_local_experts > 0, "Expected at least one expert" - if self.num_local_experts > 1: - self.expert_ids_per_ep_rank = torch.tensor( - [i % self.num_local_experts for i in range(self.num_experts)], - dtype=torch.int32, - device=torch.npu.current_device(), - ) - - local_expert_indices_offset = (self.ep_rank * self.num_local_experts) - - self.local_expert_indices = [ - local_expert_indices_offset + i - for i in range(self.num_local_experts) - ] - assert (len(self.local_expert_indices) == self.num_local_experts - ), "Invalid local expert indices" - for i in range(len(self.local_expert_indices) - 1): - assert (self.local_expert_indices[i] == - self.local_expert_indices[i + 1] - - 1), "local_expert_indices must be continuous" - self.probs = None - self.input_splits = None - self.output_splits = None - self.routing_map = None - self.hidden_shape_before_permute = None - - # [tp_ep_size * ep_size, num_local_experts]. Represents the number of tokens sent - # to each local expert by all ranks. - self.num_global_tokens_per_local_expert_cpu = None - self.num_global_tokens_per_local_expert = None - - # A cuda stream synchronization is needed in self.token_permutation() - # in some cases, because there are several non-blocking DtoH data - # transfers called in self.preprocess(). The synchronization happens - # at different points based on MoE settings as late as possible. - # Valid sync points are "before_permutation_1", "before_ep_alltoall", - # "before_finish", and "no_sync". - self.device_sync_point = "no_sync" - - # cached intermediate tensors. - self.cached_permutated_local_input_tokens = None - self.cached_global_input_tokens = None - self.cached_shared_expert_output = None - self.tokens_per_expert = None - self.perm1_finish_event = None - self.global_input_tokens_local_experts_indices = None - - if MoEAlltoAllSeqOverLapDispatcher.overlap_stream is None: - MoEAlltoAllSeqOverLapDispatcher.overlap_stream = torch.npu.Stream() - - self.overlap_stream = MoEAlltoAllSeqOverLapDispatcher.overlap_stream - - def preprocess(self, - indices: torch.Tensor, - with_sync=True) -> torch.Tensor: - """ - Preprocess routing map for AlltoAll communication and token permutation. - This method computes the number of tokens assigned to each expert based on - the routing map. It also initializes the necessary data structures for - AlltoAll communication, such as input and output splits, and the mapping - between global tokens and local experts. - - Args: - routing_map (torch.Tensor): The mapping of tokens to experts, with shape - [num_tokens, num_experts]. - - Returns: - torch.Tensor: Tensor containing the number of tokens assigned to local expert. - """ - num_local_tokens_per_expert = torch.histc(indices, - bins=self.num_experts, - min=0, - max=self.num_experts) - - # num_local_tokens_per_expert: [num_experts] - - ep_size = self.ep_size - - # Dropless - self.num_out_tokens = indices.numel() - if self.ep_size > 1 or self.num_local_experts > 1: - # Token dropless and enable ep. A synchronization is needed before expert parallel - # AlltoAll communication to get the `input_splits` and `output_splits` CPU values. - self.device_sync_point = "before_ep_alltoall" - else: - # Token dropless and no ep. A synchronization is needed to get the - # `tokens_per_expert` CPU value. - self.device_sync_point = "before_finish" - - if ep_size > 1: - # =================================================== - # Calculate input_splits, output_splits for alltoall-v. - # =================================================== - self.input_splits = (num_local_tokens_per_expert.reshape( - ep_size, self.num_local_experts).sum(axis=1).to( - torch.device("cpu"), non_blocking=True).numpy()) - num_global_tokens_per_expert = gather_from_sequence_parallel_region( - num_local_tokens_per_expert, - group=self.ep_group).reshape(ep_size, self.num_experts) - self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[ - 0]:self.local_expert_indices[-1] + 1] - if self.num_global_tokens_per_local_expert is None: - raise ValueError( - "num_global_tokens_per_local_expert must be set before sum." - ) - self.output_splits = (self.num_global_tokens_per_local_expert.sum( - axis=-1).to(torch.device("cpu"), non_blocking=True).numpy()) - num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum( - axis=0) - # =================================================== - # num_global_tokens_per_expert: [ep_size, num_experts] - # num_global_tokens_per_local_expert: [ep_size, num_local_experts] - # num_tokens_per_local_expert: [num_local_experts] - # =================================================== - else: - self.num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape( - -1, self.num_experts) - num_tokens_per_local_expert = num_local_tokens_per_expert - - if self.num_local_experts > 1 and with_sync: - if self.num_global_tokens_per_local_expert is None: - raise ValueError( - "num_global_tokens_per_local_expert must be set before operations." - ) - self.device_sync_point = "no_sync" - self.global_input_tokens_local_experts_indices = torch.repeat_interleave( - self.expert_ids_per_ep_rank, - self.num_global_tokens_per_local_expert.ravel()) - - return num_tokens_per_local_expert - - def token_permutation( - self, - hidden_states: torch.Tensor, - probs: torch.Tensor, - routing_map: torch.Tensor, - ): - """ - Dispatch tokens to local experts using AlltoAllSeq communication. - - Args: - hidden_states (torch.Tensor): Input token embeddings. - probs (torch.Tensor): Probs of tokens assigned to experts. - Shape: [num_tokens, num_experts]. - routing_map (torch.Tensor): Mapping of tokens assigned to experts. - Shape: [num_tokens, num_experts]. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: - - Permuted token embeddings for local experts. - - Number of tokens per expert. - """ - self.hidden_shape = hidden_states.shape - self.probs = probs - self.top_indices = routing_map - assert probs.dim() == 2, "Expected 2D tensor for probs" - assert routing_map.dim() == 2, "Expected 2D tensor for routing map" - - # Permutation 1: input to AlltoAll input - def alltoall_token_permutation1(hidden_states, routing_map): - assert self.hidden_shape is not None - hidden_states = hidden_states.view(-1, self.hidden_shape[-1]) - tokens_per_expert = self.preprocess(routing_map) - if self.tp_ep_size > 1: - hidden_states = all_to_all_sp2hp(hidden_states, - group=self.tp_ep_group) - self.hidden_shape_before_permute = hidden_states.shape - - if self.device_sync_point == "before_permutation_1": - torch.npu.current_stream().synchronize() - - permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute( - tokens=hidden_states, - indices=self.top_indices, - num_out_tokens=self.num_out_tokens, - ) - return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert - - permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = alltoall_token_permutation1( - hidden_states, routing_map) - self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping - # permute 1 - - ep_group = self.ep_group - - # Perform expert parallel AlltoAll communication - if self.device_sync_point == "before_ep_alltoall": - torch.npu.current_stream().synchronize() - _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all( - permutated_local_input_tokens, - self.output_splits, - self.input_splits, - ep_group, - ) - - # shared experts compute - if self.shared_experts is not None: - (share_experts_output), *_ = self.shared_experts(hidden_states) - else: - share_experts_output = None - - permute1_ep_all_to_all_handle.wait() - permutated_local_input_tokens.untyped_storage().resize_(0) - - def alltoall_token_permutation2(global_input_tokens): - # Permutation 2: Sort tokens by local expert. - if self.num_local_experts > 1: - global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute( - global_input_tokens, - self.global_input_tokens_local_experts_indices) - - # Perform tensor parallel AllGather on the hidden dimension to obtain the input tokens. - # global_input_tokens: [SEQL, H/TP] -> [SEQL, H] - if self.tp_ep_size > 1 and self.config.moe_grouped_gemm: - global_input_tokens = all_gather_last_dim_from_tensor_parallel_region( - global_input_tokens, self.tp_ep_group) - if self.device_sync_point == "before_finish": - torch.npu.current_stream().synchronize() - - return global_input_tokens - - # token premute2 input - global_input_tokens = alltoall_token_permutation2(global_input_tokens) - - return share_experts_output, global_input_tokens, tokens_per_expert - - def token_unpermutation(self, - hidden_states: torch.Tensor, - bias: torch.Tensor = None): - """ - Reverse the token permutation to restore the original order. - - Args: - hidden_states (torch.Tensor): Output from local experts. - bias (torch.Tensor, optional): Bias tensor (not supported). - - Returns: - Tuple[torch.Tensor, Optional[torch.Tensor]]: - - Unpermuted token embeddings in the original order. - - None (bias is not supported). - """ - - def alltoall_token_unpermutation1(hidden_states): - assert bias is None, "Bias is not supported in MoEAlltoAllSeqTokenDispatcher" - # Perform tensor parallel Reduce-Scatter - # hidden_states: [SEQL, H] -> [SEQL, H/TP] - if self.tp_ep_size > 1: - hidden_states = reduce_scatter_last_dim_to_tensor_parallel_region( - hidden_states, group=self.tp_ep_group) - - # Unpermutation 2: expert output to AlltoAll input - if hidden_states.shape[0] > 0 and self.num_local_experts > 1: - hidden_states = torch_npu.npu_moe_token_unpermute( - hidden_states, - self.reversed_global_input_permutation_mapping) - - return hidden_states - - hidden_states = alltoall_token_unpermutation1(hidden_states) - - ep_group = self.ep_group - # Perform expert parallel AlltoAll communication - # hidden_states: [SEQL, H] -> [SEQL, H/TP] - _, permutated_local_input_tokens, handle = async_all_to_all( - hidden_states, self.input_splits, self.output_splits, ep_group) - handle.wait() - hidden_states.untyped_storage().resize_(0) - - def alltoall_token_unpermutation2(permutated_local_input_tokens): - # Unpermutation 1: AlltoAll output to output - - output = torch_npu.npu_moe_token_unpermute( - permuted_tokens=permutated_local_input_tokens, - sorted_indices=self.reversed_local_input_permutation_mapping. - to(torch.int32), - probs=self.probs, - restore_shape=self.hidden_shape_before_permute) - - # Perform tensor parallel AlltoAll communication - # output: [S*B, H/TP] -> [S*B/TP, H] - if self.tp_ep_size > 1: - output = all_to_all_hp2sp(output, self.tp_ep_group) - - # Reshape the output tensor - output = output.view(self.hidden_shape) - return output - - output = alltoall_token_unpermutation2(permutated_local_input_tokens) - - self.input_splits = None - self.output_splits = None - self.num_global_tokens_per_local_expert = None - self.num_global_tokens_per_local_expert_cpu = None - - return output, None - - _Dispatchers: Dict[str, Any] = {} @@ -1090,7 +667,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): def token_combine(self, hidden_states: torch.Tensor, bias: torch.Tensor = None): - assert bias is None, "Bias is not supported in MoEAlltoAllSeqTokenDispatcher" + assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher." hidden_states = self._combine_preprocess(hidden_states) diff --git a/vllm_ascend/torchair/ops/torchair_fused_moe.py b/vllm_ascend/torchair/ops/torchair_fused_moe.py index 72203cd..42e8659 100644 --- a/vllm_ascend/torchair/ops/torchair_fused_moe.py +++ b/vllm_ascend/torchair/ops/torchair_fused_moe.py @@ -38,15 +38,12 @@ from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.quantization.base_config import \ QuantizationConfig -import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.communication_op import \ data_parallel_reduce_scatter from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer -from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( - MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig) from vllm_ascend.ops.sequence_parallel import MetadataForPadding from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor from vllm_ascend.utils import (AscendSocVersion, dispose_tensor, @@ -54,74 +51,6 @@ from vllm_ascend.utils import (AscendSocVersion, dispose_tensor, get_ascend_soc_version, get_rm_router_logits_state, is_310p) -MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER - - -def torchair_process_topk_ids(topk_ids: torch.Tensor, expert_num: int, - ep_size: int, max_row_per_ep_rank: int, - num_tokens: int, - top_k: int) -> tuple[torch.Tensor, torch.Tensor]: - original_total_elements = num_tokens * top_k - device = topk_ids.device - original_dtype = topk_ids.dtype - - if original_total_elements == 0: - output_len = ep_size * max_row_per_ep_rank - topk_ids_pad = torch.full((output_len, ), - expert_num, - dtype=original_dtype, - device=device) - unpad_indices = torch.full((original_total_elements, ), - -1, - dtype=torch.long, - device=device) - return topk_ids_pad, unpad_indices - - experts_per_ep_rank_val = expert_num // ep_size - if experts_per_ep_rank_val == 0: - raise ValueError( - "expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. " - "Ensure expert_num >= ep_size.") - - assigned_ep_rank = (topk_ids.float() / - experts_per_ep_rank_val).to(original_dtype) - indices_arange = torch.arange(topk_ids.shape[0], device=device) - - is_new_segment = torch.cat( - (torch.tensor([True], device=device), assigned_ep_rank[1:] - != assigned_ep_rank[:-1])) - temp_start_markers = torch.full_like(indices_arange, - -1, - dtype=indices_arange.dtype) - temp_start_markers[is_new_segment] = indices_arange[is_new_segment] - start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0] - token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token - is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank - cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long) - indices_in_rec_cond_list_for_all = cumsum_kept - 1 - unpad_indices = torch.where( - is_kept_mask, indices_in_rec_cond_list_for_all, - torch.tensor(-1, device=device, dtype=torch.long)) - output_len = ep_size * max_row_per_ep_rank - topk_ids_pad = torch.full((output_len, ), - expert_num, - dtype=original_dtype, - device=device) - if topk_ids.shape[0] > 0: - all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx - temp_pad_buffer = torch.full((output_len + 1, ), - expert_num, - dtype=original_dtype, - device=device) - output_len_tensor = torch.tensor(output_len, - dtype=torch.long, - device=device) - scatter_indices = torch.where(is_kept_mask, all_destination_indices, - output_len_tensor) - temp_pad_buffer.scatter_(0, scatter_indices, topk_ids) - topk_ids_pad = temp_pad_buffer[:output_len] - return topk_ids_pad, unpad_indices - def torchair_fused_experts_with_mc2( hidden_states: torch.Tensor, @@ -459,130 +388,6 @@ def torchair_fused_experts_with_all2all( return final_hidden_states -# currently expert parallelism implemented with all2all -# is under-optimized. -def torchair_fused_experts_with_all2all_buffer( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - max_model_len: int, - global_batch_size: int, - expert_map: torch.Tensor = None, - ep_group: GroupCoordinator = None, -): - original_shape = hidden_states.shape - if len(original_shape) == 3: - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - - num_tokens, _ = hidden_states.shape - device = hidden_states.device - - global_num_experts = len(expert_map) - local_num_experts = global_num_experts // ep_group.world_size - row_idx_len = num_tokens * top_k - row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32, - device=device).view(top_k, - -1).permute(1, 0).contiguous()) - hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) - - max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) * - max_model_len // ep_group.world_size + - 1) * top_k * 2 - expert_idx_buffer_scatter, unpad_indices = torchair_process_topk_ids( - expanded_expert_idx, global_num_experts, ep_group.world_size, - max_row_per_ep_rank, num_tokens, top_k) - hidden_states_pad_idx = torch.zeros( - expert_idx_buffer_scatter.shape, - dtype=expert_idx_buffer_scatter.dtype, - device=expert_idx_buffer_scatter.device) - non_pad_len = torch.sum((expert_idx_buffer_scatter - != global_num_experts).to(torch.int32)) - hidden_states_pad_idx[expert_idx_buffer_scatter != - global_num_experts] = torch.arange( - non_pad_len, - dtype=expert_idx_buffer_scatter.dtype, - device=hidden_states.device) - - hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx] - expert_idx_buffer_gather = torch.empty_like( - expert_idx_buffer_scatter, - dtype=expert_idx_buffer_scatter.dtype, - device=expert_idx_buffer_scatter.device) - hidden_states_buffer_gather = torch.empty_like( - hidden_states_buffer_scatter, - dtype=hidden_states_buffer_scatter.dtype, - device=hidden_states_buffer_scatter.device) - dist.all_to_all_single(expert_idx_buffer_gather, - expert_idx_buffer_scatter, - group=ep_group.device_group) - dist.all_to_all_single(hidden_states_buffer_gather, - hidden_states_buffer_scatter, - group=ep_group.device_group) - mask = expert_idx_buffer_gather != global_num_experts - local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * ( - global_num_experts // ep_group.world_size) - hidden_states = hidden_states_buffer_gather[mask] - idx_type = local_expert_idx.dtype - sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float()) - sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type) - - expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - sorted_local_expert_idx, local_num_experts).to(torch.int64) - hidden_states = hidden_states[sorted_idx] - group_list_type = 0 - - hidden_states = torchair_apply_mlp(hidden_states, - w1, - w2, - expert_tokens, - group_list_type=group_list_type) - - resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype) - hidden_states = hidden_states[resorted_idx] - hidden_states_scatter = torch.zeros( - (mask.shape[0], hidden_states.shape[1]), - dtype=hidden_states.dtype, - device=hidden_states.device) - hidden_states_scatter[mask] = hidden_states - hidden_states_gatter = torch.empty_like( - hidden_states_scatter, - dtype=hidden_states_scatter.dtype, - device=hidden_states_scatter.device) - dist.all_to_all_single(hidden_states_gatter, - hidden_states_scatter, - group=ep_group.device_group) - hidden_states_gatter = hidden_states_gatter[expert_idx_buffer_scatter != - global_num_experts] - if hidden_states_gatter.shape[0] != row_idx_len: - hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]), - dtype=hidden_states.dtype, - device=hidden_states.device) - hidden_states[unpad_indices != -1] = hidden_states_gatter - else: - # TODO: Reorder device memory 2 times here, replace the current - hidden_states = hidden_states_gatter - final_hidden_states = torch_npu.npu_moe_finalize_routing( - hidden_states, - skip1=None, - skip2=None, - bias=None, - scales=topk_weights, - expanded_src_to_dst_row=expanded_row_idx, - export_for_source_row=topk_ids, - ) - - if len(original_shape) == 3: - final_hidden_states = final_hidden_states.view(original_shape) - return final_hidden_states - - def torchair_fused_experts_moge( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -674,25 +479,6 @@ def torchair_fused_experts_moge( return final_hidden_states -def torchair_fused_experts_with_all2allv( - token_dispatcher, - probs, - routing_map, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, -): - # Enable moe alltoallv, it's a balanced policy for precision and efficiency. - (share_experts_output, dispatched_input, - tokens_per_expert) = (token_dispatcher.token_permutation( - hidden_states, probs, routing_map)) - - expert_output = torchair_apply_mlp(dispatched_input, w1, w2, - tokens_per_expert) - output, mlp_bias = token_dispatcher.token_unpermutation(expert_output) - return output - - def torchair_fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -1120,28 +906,6 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): topk_ids=topk_ids, top_k=top_k, expert_map=expert_map) - elif MOE_ALL2ALL_BUFFER: - return torchair_fused_experts_with_all2all_buffer( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - max_model_len=self.max_model_len, - global_batch_size=self.global_batch_size, - expert_map=expert_map, - ep_group=get_ep_group()) - elif fused_moe_state == FusedMoEState.All2AllSeq: - token_dispatcher = kwargs.get("token_dispatcher") - return torchair_fused_experts_with_all2allv( - token_dispatcher=token_dispatcher, - probs=topk_weights, - routing_map=topk_ids, - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - ) else: return torchair_fused_experts_with_all2all( hidden_states=x, @@ -1315,25 +1079,6 @@ class TorchairAscendFusedMoE(FusedMoE): # NOTE: self.tp_group is not expert_tp_group self.tp_group = get_tp_group().device_group self.quant_method.create_weights(layer=self, **moe_quant_params) - self.token_dispatcher = None - if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ and isinstance( - self.quant_method, TorchairAscendUnquantizedFusedMoEMethod): - self.reduce_results = False - moe_dispatcher_config = ( - MoEDispatcherConfig().set_num_moe_experts( - self.global_num_experts).set_num_local_experts( - self.local_num_experts).set_moe_router_topk( - top_k).set_group_topk(topk_group). - set_num_groups(num_expert_group).set_expert_bias( - e_score_correction_bias).set_scaling_factor(1.0).build()) - self.token_dispatcher = MoEAlltoAllSeqOverLapDispatcher( - moe_dispatcher_config) - if envs_ascend.VLLM_ASCEND_ENABLE_DBO: - token_dispatcher1 = MoEAlltoAllSeqOverLapDispatcher( - moe_dispatcher_config) - self.token_dispatchers = [ - self.token_dispatcher, token_dispatcher1 - ] def naive_multicast(self, x: torch.Tensor, cu_tokens_across_dp_cpu: torch.Tensor): @@ -1486,7 +1231,6 @@ class TorchairAscendFusedMoE(FusedMoE): shared_experts=shared_experts if self.torchair_graph_enabled and self.enable_multistream_moe and not is_prefill else None, mc2_mask=mc2_mask, - token_dispatcher=self.token_dispatcher, quantized_x_for_share=quantized_x_for_share, dynamic_scale_for_share=dynamic_scale_for_share, )