[Refactor] [MoE] Rename moe-related classes & files (#3646)

### What this PR does / why we need it?
1. Rename common_fused_moe.py to fused_moe.py.
2. Rename fused_moe_prepare_and_finalize.py / FusedMoEPrepareAndFinalize
to prepare_finalize.py / PrepareAndFinalize.
3. Rename vllm_ascend/ops/moe to vllm_ascend/ops/fused_moe.
4. Move vllm_ascend/ops/fused_moe.py to
vllm_ascend/ops/fused_moe/fused_moe.py
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e & ut

- vLLM version: v0.11.0rc3
- vLLM main:
17c540a993

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
weichen
2025-10-25 11:22:03 +08:00
committed by GitHub
parent 0637e8f021
commit 63c363d3de
25 changed files with 183 additions and 199 deletions

View File

@@ -28,9 +28,10 @@ import torch
import torch_npu import torch_npu
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm_ascend.ops.moe.experts_selector import select_experts from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.moe.token_dispatcher import TokenDispatcherWithAllGather from vllm_ascend.ops.fused_moe.token_dispatcher import \
TokenDispatcherWithAllGather
NUM_EXPERTS = [8, 64] NUM_EXPERTS = [8, 64]
EP_SIZE = [1] EP_SIZE = [1]
@@ -182,7 +183,7 @@ def test_token_dispatcher_with_all_gather_quant(
): ):
context_mock = MagicMock() context_mock = MagicMock()
context_mock.fused_moe_state = 0 context_mock.fused_moe_state = 0
with patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context", with patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context",
return_value=context_mock): return_value=context_mock):
a = torch.randn((m, k), device=device, dtype=dtype) / 10 a = torch.randn((m, k), device=device, dtype=dtype) / 10
w1 = torch.randn((e, k, 2 * n), device=device, dtype=torch.int8) w1 = torch.randn((e, k, 2 * n), device=device, dtype=torch.int8)
@@ -282,9 +283,9 @@ def test_select_experts(
dtype=torch.int32) dtype=torch.int32)
custom_routing_function.return_value = (mock_weights, mock_ids) custom_routing_function.return_value = (mock_weights, mock_ids)
with patch("vllm_ascend.ops.moe.experts_selector._native_grouped_topk" with patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk"
) as mock_native_grouped_topk, \ ) as mock_native_grouped_topk, \
patch('vllm_ascend.ops.moe.experts_selector.get_forward_context', patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
return_value=MagicMock(weight_prefetch_method=MagicMock())): return_value=MagicMock(weight_prefetch_method=MagicMock())):
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like( mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
x) x)
@@ -318,7 +319,7 @@ def test_select_experts(
@pytest.mark.parametrize("device", DEVICE) @pytest.mark.parametrize("device", DEVICE)
def test_select_experts_invalid_scoring_func(device: str): def test_select_experts_invalid_scoring_func(device: str):
with patch('vllm_ascend.ops.moe.experts_selector.get_forward_context', with patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
return_value=MagicMock(weight_prefetch_method=MagicMock())), \ return_value=MagicMock(weight_prefetch_method=MagicMock())), \
pytest.raises(ValueError, pytest.raises(ValueError,
match="Unsupported scoring function: invalid"): match="Unsupported scoring function: invalid"):

View File

@@ -90,9 +90,9 @@ def mock_distributed():
mock_vllm_config.scheduler_config = Mock(max_num_seqs=256) mock_vllm_config.scheduler_config = Mock(max_num_seqs=256)
mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None) mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None)
with patch("vllm_ascend.ops.common_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \ with patch("vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
patch("vllm_ascend.ops.moe.token_dispatcher.torch.distributed.get_rank", return_value=0), \ patch("vllm_ascend.ops.fused_moe.token_dispatcher.torch.distributed.get_rank", return_value=0), \
patch("vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version", return_value=None), \ patch("vllm_ascend.ops.fused_moe.token_dispatcher.get_ascend_soc_version", return_value=None), \
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group, patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
_PP=pp_group), \ _PP=pp_group), \
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group), \ patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group), \

View File

@@ -20,7 +20,7 @@ import torch
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from tests.ut.base import PytestBase from tests.ut.base import PytestBase
from vllm_ascend.ops.moe.comm_utils import ( from vllm_ascend.ops.fused_moe.comm_utils import (
_gather_along_first_dim, async_all_to_all, _gather_along_first_dim, async_all_to_all,
gather_from_sequence_parallel_region) gather_from_sequence_parallel_region)

View File

@@ -1,56 +0,0 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from unittest.mock import patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
class TestLoadWeight(TestBase):
def test_load_w13_transpose(self):
with patch.object(AscendFusedMoE, "__init__",
lambda self, *args, **kwargs: None):
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
expert_data = torch.randn(128, 8)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
expert_data = torch.randn(8, 128)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
expert_data = torch.randn(128, 8)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
expert_data = torch.randn(8, 128)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
def test_load_w2_transpose(self):
with patch.object(AscendFusedMoE, "__init__",
lambda self, *args, **kwargs: None):
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
expert_data = torch.randn(128, 4)
loaded_weight = torch.randn(128, 8)
moe._load_w2(expert_data, 1, loaded_weight, 0)
expert_data = torch.randn(4, 128)
loaded_weight = torch.randn(128, 8)
moe._load_w2(expert_data, 1, loaded_weight, 0)

View File

@@ -24,9 +24,11 @@ from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
from tests.ut.base import TestBase from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.moe.experts_selector import select_experts from vllm_ascend.ops.fused_moe.fused_moe import (
from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp AscendFusedMoE, AscendUnquantizedFusedMoEMethod)
from vllm_ascend.ops.fused_moe.moe_mlp import (cumsum_group_list,
unified_apply_mlp)
from vllm_ascend.utils import AscendSocVersion, adapt_patch from vllm_ascend.utils import AscendSocVersion, adapt_patch
adapt_patch(True) adapt_patch(True)
@@ -69,10 +71,11 @@ def setup_vllm_config_mock(mocker: MockerFixture):
mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4) mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
mock_vllm_config.model_config.max_model_len = 2048 mock_vllm_config.model_config.max_model_len = 2048
mocker.patch('vllm_ascend.ops.common_fused_moe.get_current_vllm_config', mocker.patch('vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config',
return_value=mock_vllm_config)
mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
return_value=mock_vllm_config) return_value=mock_vllm_config)
mocker.patch(
'vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config',
return_value=mock_vllm_config)
@pytest.fixture @pytest.fixture
@@ -105,37 +108,37 @@ def mock_dist_env(mocker: MockerFixture):
with patch('torch.distributed.get_rank', return_value=0), \ with patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \ patch('torch.distributed.get_world_size', return_value=4), \
patch('vllm_ascend.ops.common_fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \ patch('vllm_ascend.ops.fused_moe.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \ patch('vllm_ascend.ops.fused_moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.common_fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \ patch('vllm_ascend.ops.fused_moe.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.common_fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm_ascend.ops.fused_moe.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.common_fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm_ascend.ops.fused_moe.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group', patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
return_value=mock_dp_and_tp_group(mocker)), \ return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.common_fused_moe.get_ascend_config', patch('vllm_ascend.ops.fused_moe.fused_moe.get_ascend_config',
return_value=MagicMock( return_value=MagicMock(
torchair_graph_config=MagicMock(enabled=False), torchair_graph_config=MagicMock(enabled=False),
enable_multistream_moe=False, enable_multistream_moe=False,
expert_map_path=None expert_map_path=None
)), \ )), \
patch('vllm_ascend.ops.common_fused_moe.determine_expert_map', patch('vllm_ascend.ops.fused_moe.fused_moe.determine_expert_map',
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \ return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
patch('vllm_ascend.ops.common_fused_moe.get_forward_context', patch('vllm_ascend.ops.fused_moe.fused_moe.get_forward_context',
return_value=mock_forward_context_obj), \ return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context', patch('vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context',
return_value=mock_forward_context_obj), \ return_value=mock_forward_context_obj), \
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \ patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context', patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context',
return_value=mock_forward_context_obj), \ return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.moe.moe_comm_method.MC2CommImpl._get_token_dispatcher', patch('vllm_ascend.ops.fused_moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
return_value=None), \ return_value=None), \
patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher', patch('vllm_ascend.ops.fused_moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
return_value=None), \ return_value=None), \
patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher', patch('vllm_ascend.ops.fused_moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
return_value=None), \ return_value=None), \
patch('vllm_ascend.ops.moe.experts_selector.get_forward_context', patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
return_value=mock_forward_context_obj): return_value=mock_forward_context_obj):
yield { yield {
@@ -319,8 +322,8 @@ class TestCumsumGroupList(TestBase):
class TestUnifiedApplyMLP(TestBase): class TestUnifiedApplyMLP(TestBase):
@patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context') @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p') @patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
@patch('torch_npu.npu_grouped_matmul') @patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_dynamic_quant') @patch('torch_npu.npu_dynamic_quant')
@patch('torch_npu.npu_dequant_swiglu_quant') @patch('torch_npu.npu_dequant_swiglu_quant')
@@ -384,7 +387,7 @@ class TestUnifiedApplyMLP(TestBase):
self.assertEqual(result.dtype, torch.bfloat16) self.assertEqual(result.dtype, torch.bfloat16)
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p') @patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
@patch('torch_npu.npu_grouped_matmul') @patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu') @patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant') @patch('torch_npu.npu_dynamic_quant')
@@ -426,7 +429,7 @@ class TestUnifiedApplyMLP(TestBase):
self.assertEqual(result.shape, hidden_states.shape) self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.float16) self.assertEqual(result.dtype, torch.float16)
@patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context') @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
@patch('torch_npu.npu_grouped_matmul') @patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu') @patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant') @patch('torch_npu.npu_dynamic_quant')
@@ -486,7 +489,7 @@ class TestUnifiedApplyMLP(TestBase):
self.assertEqual(result.shape, hidden_states.shape) self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.bfloat16) self.assertEqual(result.dtype, torch.bfloat16)
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p') @patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
@patch('torch_npu.npu_grouped_matmul') @patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu') @patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant') @patch('torch_npu.npu_dynamic_quant')
@@ -531,7 +534,7 @@ class TestUnifiedApplyMLP(TestBase):
self.assertEqual(result.shape, hidden_states.shape) self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.float16) self.assertEqual(result.dtype, torch.float16)
@patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context") @patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context")
@patch("torch_npu.npu_grouped_matmul") @patch("torch_npu.npu_grouped_matmul")
@patch("torch_npu.npu_swiglu") @patch("torch_npu.npu_swiglu")
@patch("torch_npu.npu_grouped_matmul_swiglu_quant") @patch("torch_npu.npu_grouped_matmul_swiglu_quant")
@@ -595,3 +598,39 @@ class TestUnifiedApplyMLP(TestBase):
self.assertTrue(mock_forward_context.with_quant) self.assertTrue(mock_forward_context.with_quant)
self.assertEqual(result.shape, hidden_states.shape) self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.bfloat16) self.assertEqual(result.dtype, torch.bfloat16)
class TestLoadWeight(TestBase):
def test_load_w13_transpose(self):
with patch.object(AscendFusedMoE, "__init__",
lambda self, *args, **kwargs: None):
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
expert_data = torch.randn(128, 8)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
expert_data = torch.randn(8, 128)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
expert_data = torch.randn(128, 8)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
expert_data = torch.randn(8, 128)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
def test_load_w2_transpose(self):
with patch.object(AscendFusedMoE, "__init__",
lambda self, *args, **kwargs: None):
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
expert_data = torch.randn(128, 4)
loaded_weight = torch.randn(128, 8)
moe._load_w2(expert_data, 1, loaded_weight, 0)
expert_data = torch.randn(4, 128)
loaded_weight = torch.randn(128, 8)
moe._load_w2(expert_data, 1, loaded_weight, 0)

View File

@@ -4,8 +4,9 @@ import torch
from vllm.model_executor.layers.fused_moe import FusedMoEConfig from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from tests.ut.base import TestBase from tests.ut.base import TestBase
from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl, from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
AlltoAllCommImpl, MC2CommImpl) AlltoAllCommImpl,
MC2CommImpl)
class TestMoECommMethod(TestBase): class TestMoECommMethod(TestBase):
@@ -24,12 +25,14 @@ class TestMoECommMethod(TestBase):
self.moe_config.dp_group = MagicMock() self.moe_config.dp_group = MagicMock()
self.moe_config.num_global_redundant_experts = 0 self.moe_config.num_global_redundant_experts = 0
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
@patch( @patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather" "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
)
@patch(
"vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
) )
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
def test_all_gather_comm_impl(self, mock_token_dispatcher, def test_all_gather_comm_impl(self, mock_token_dispatcher,
mock_prepare_finalize, mock_prepare_finalize,
mock_get_forward_context, mock_get_forward_context,
@@ -72,12 +75,11 @@ class TestMoECommMethod(TestBase):
context_metadata=context_metadata) context_metadata=context_metadata)
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None) mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
@patch( @patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithMC2" "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithMC2")
) @patch("vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithMC2")
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithMC2")
def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize, def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
mock_get_forward_context, mock_get_forward_context,
mock_get_current_vllm_config): mock_get_current_vllm_config):
@@ -121,12 +123,14 @@ class TestMoECommMethod(TestBase):
context_metadata=context_metadata) context_metadata=context_metadata)
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None) mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
@patch( @patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAll2All" "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAll2All"
)
@patch(
"vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAll2AllV"
) )
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAll2AllV")
def test_alltoall_comm_impl(self, mock_token_dispatcher, def test_alltoall_comm_impl(self, mock_token_dispatcher,
mock_prepare_finalize, mock_prepare_finalize,
mock_get_forward_context, mock_get_forward_context,
@@ -163,13 +167,15 @@ class TestMoECommMethod(TestBase):
mock_pf_instance.prepare.assert_called_once_with( mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, None) hidden_states, router_logits, False, False, None)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
@patch( @patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather" "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
) )
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather") @patch(
@patch("vllm_ascend.ops.moe.moe_comm_method.unified_apply_mlp") "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
)
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.unified_apply_mlp")
def test_fused_experts_method(self, mock_unified_apply_mlp, def test_fused_experts_method(self, mock_unified_apply_mlp,
mock_token_dispatcher, mock_prepare_finalize, mock_token_dispatcher, mock_prepare_finalize,
mock_get_forward_context, mock_get_forward_context,

View File

@@ -4,13 +4,12 @@ from unittest.mock import MagicMock, patch
import torch import torch
from vllm.model_executor.layers.fused_moe import FusedMoEConfig from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import ( from vllm_ascend.ops.fused_moe.prepare_finalize import (
FusedMoEPrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather,
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2, PrepareAndFinalizeWithMC2, PrepareAndFinalizeWithNaiveMulticast)
FusedMoEPrepareAndFinalizeWithNaiveMulticast)
class TestFusedMoEPrepareAndFinalize(unittest.TestCase): class TestPrepareAndFinalize(unittest.TestCase):
def setUp(self): def setUp(self):
# Mock FusedMoEConfig # Mock FusedMoEConfig
@@ -24,14 +23,12 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
self.moe_config.original_num_experts = 8 self.moe_config.original_num_experts = 8
@patch( @patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size", "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
return_value=1) return_value=1)
@patch( @patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank", "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
return_value=0) return_value=0)
@patch( @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
)
def test_mc2_prepare_finalize(self, mock_get_forward_context, mock_tp_rank, def test_mc2_prepare_finalize(self, mock_get_forward_context, mock_tp_rank,
mock_tp_size): mock_tp_size):
mock_context = MagicMock() mock_context = MagicMock()
@@ -39,7 +36,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
mock_context.padded_num_tokens = 4 mock_context.padded_num_tokens = 4
mock_get_forward_context.return_value = mock_context mock_get_forward_context.return_value = mock_context
layer = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config) layer = PrepareAndFinalizeWithMC2(self.moe_config)
hidden_states = torch.randn(3, 8) hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2) router_logits = torch.randn(3, 2)
@@ -59,14 +56,12 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
self.assertEqual(result.shape[0], 3) self.assertEqual(result.shape[0], 3)
@patch( @patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size", "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
return_value=2) return_value=2)
@patch( @patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank", "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
return_value=0) return_value=0)
@patch( @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
)
@patch("torch.distributed.all_gather") @patch("torch.distributed.all_gather")
def test_mc2_tp_split_allgather(self, mock_all_gather, def test_mc2_tp_split_allgather(self, mock_all_gather,
mock_get_forward_context, mock_tp_rank, mock_get_forward_context, mock_tp_rank,
@@ -76,7 +71,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
mock_context.padded_num_tokens = 4 mock_context.padded_num_tokens = 4
mock_get_forward_context.return_value = mock_context mock_get_forward_context.return_value = mock_context
layer = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config) layer = PrepareAndFinalizeWithMC2(self.moe_config)
hidden_states = torch.randn(4, 8) hidden_states = torch.randn(4, 8)
router_logits = torch.randn(4, 2) router_logits = torch.randn(4, 2)
@@ -108,13 +103,13 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
self.assertEqual(final_result.shape[0], 4) self.assertEqual(final_result.shape[0], 4)
@patch( @patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size", "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
return_value=1) return_value=1)
@patch( @patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank", "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
return_value=0) return_value=0)
def test_all2all_prepare_finalize(self, mock_tp_rank, mock_tp_size): def test_all2all_prepare_finalize(self, mock_tp_rank, mock_tp_size):
layer = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config) layer = PrepareAndFinalizeWithAll2All(self.moe_config)
hidden_states = torch.randn(3, 8) hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2) router_logits = torch.randn(3, 2)
@@ -130,15 +125,15 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
self.assertEqual(result.shape[0], 3) self.assertEqual(result.shape[0], 3)
@patch( @patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size", "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
return_value=2) return_value=2)
@patch( @patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank", "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
return_value=0) return_value=0)
@patch("torch.distributed.all_gather") @patch("torch.distributed.all_gather")
def test_all2all_tp_split_allgather(self, mock_all_gather, mock_tp_rank, def test_all2all_tp_split_allgather(self, mock_all_gather, mock_tp_rank,
mock_tp_size): mock_tp_size):
layer = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config) layer = PrepareAndFinalizeWithAll2All(self.moe_config)
hidden_states = torch.randn(2, 8) hidden_states = torch.randn(2, 8)
router_logits = torch.randn(2, 2) router_logits = torch.randn(2, 2)
@@ -169,14 +164,15 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
# Should concat back # Should concat back
self.assertEqual(final_result.shape[0], 2) self.assertEqual(final_result.shape[0], 2)
@patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group") @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_dp_group")
@patch( @patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce" "vllm_ascend.ops.fused_moe.prepare_finalize.tensor_model_parallel_all_reduce"
) )
@patch( @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context" @patch("vllm_ascend.ops.fused_moe.prepare_finalize.enable_sp",
) return_value=False)
def test_allgather_prepare_finalize(self, mock_get_forward_context, def test_allgather_prepare_finalize(self, mock_enable_sp,
mock_get_forward_context,
mock_tp_all_reduce, mock_get_dp_group): mock_tp_all_reduce, mock_get_dp_group):
# Mock forward context # Mock forward context
mock_context = MagicMock() mock_context = MagicMock()
@@ -198,7 +194,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
self.moe_config.ep_size = 1 self.moe_config.ep_size = 1
self.moe_config.dp_group = mock_dp_group self.moe_config.dp_group = mock_dp_group
layer = FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config) layer = PrepareAndFinalizeWithAllGather(self.moe_config)
hidden_states = torch.randn(3, 8) hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2) router_logits = torch.randn(3, 2)
@@ -232,13 +228,11 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
result_with_tp = layer.finalize(h_out, reduce_results=True) result_with_tp = layer.finalize(h_out, reduce_results=True)
self.assertEqual(result_with_tp.shape[0], 3) self.assertEqual(result_with_tp.shape[0], 3)
@patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group") @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_dp_group")
@patch( @patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce" "vllm_ascend.ops.fused_moe.prepare_finalize.tensor_model_parallel_all_reduce"
)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
) )
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
def test_naive_multicast_prepare_finalize(self, mock_get_forward_context, def test_naive_multicast_prepare_finalize(self, mock_get_forward_context,
mock_tp_all_reduce, mock_tp_all_reduce,
mock_get_dp_group): mock_get_dp_group):
@@ -266,7 +260,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
self.moe_config.tp_size = 1 self.moe_config.tp_size = 1
self.moe_config.ep_size = 1 self.moe_config.ep_size = 1
layer = FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config) layer = PrepareAndFinalizeWithNaiveMulticast(self.moe_config)
# Local inputs # Local inputs
hidden_states = torch.randn(3, 8) hidden_states = torch.randn(3, 8)

View File

@@ -21,7 +21,7 @@ import torch
from tests.ut.base import TestBase from tests.ut.base import TestBase
from vllm_ascend.ops.moe.token_dispatcher import ( # isort: skip from vllm_ascend.ops.fused_moe.token_dispatcher import ( # isort: skip
AscendSocVersion, TokenDispatcherWithAll2AllV, AscendSocVersion, TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather, TokenDispatcherWithMC2) TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
@@ -34,7 +34,7 @@ class TestTokenDispatcherWithMC2(TestBase):
self.mc2_group.rank_in_group = 0 self.mc2_group.rank_in_group = 0
self.mc2_group.world_size = 8 self.mc2_group.world_size = 8
self.mc2_group_patch = patch( self.mc2_group_patch = patch(
"vllm_ascend.ops.moe.token_dispatcher.get_mc2_group", "vllm_ascend.ops.fused_moe.token_dispatcher.get_mc2_group",
return_value=self.mc2_group) return_value=self.mc2_group)
self.mc2_group_patch.start() self.mc2_group_patch.start()
@@ -52,7 +52,7 @@ class TestTokenDispatcherWithMC2(TestBase):
# Mock get_ascend_soc_version() # Mock get_ascend_soc_version()
self.ascend_soc_version_patch = patch( self.ascend_soc_version_patch = patch(
"vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version", "vllm_ascend.ops.fused_moe.token_dispatcher.get_ascend_soc_version",
return_value=AscendSocVersion.A3) return_value=AscendSocVersion.A3)
self.ascend_soc_version_patch.start() self.ascend_soc_version_patch.start()
@@ -369,7 +369,8 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16) self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16)
# Mock async_all_to_all # Mock async_all_to_all
patcher6 = patch('vllm_ascend.ops.moe.comm_utils.async_all_to_all') patcher6 = patch(
'vllm_ascend.ops.fused_moe.comm_utils.async_all_to_all')
self.mock_async_all_to_all = patcher6.start() self.mock_async_all_to_all = patcher6.start()
self.addCleanup(patcher6.stop) self.addCleanup(patcher6.stop)
self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16), self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16),
@@ -377,7 +378,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
# Mock gather_from_sequence_parallel_region # Mock gather_from_sequence_parallel_region
patcher7 = patch( patcher7 = patch(
'vllm_ascend.ops.moe.token_dispatcher.gather_from_sequence_parallel_region' 'vllm_ascend.ops.fused_moe.token_dispatcher.gather_from_sequence_parallel_region'
) )
self.mock_gather_from_sequence_parallel_region = patcher7.start() self.mock_gather_from_sequence_parallel_region = patcher7.start()
self.addCleanup(patcher7.stop) self.addCleanup(patcher7.stop)

View File

@@ -5,8 +5,8 @@ import torch
from tests.ut.base import TestBase from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.moe.experts_selector import (_native_grouped_topk, from vllm_ascend.ops.fused_moe.experts_selector import (_native_grouped_topk,
select_experts) select_experts)
from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod, from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
AscendW8A8FusedMoEMethod, AscendW8A8FusedMoEMethod,
AscendW8A8LinearMethod, AscendW8A8LinearMethod,
@@ -758,7 +758,7 @@ class TestSelectExperts(TestBase):
self.mock_ctx = MagicMock() self.mock_ctx = MagicMock()
self.mock_ctx.weight_prefetch_method = MagicMock() self.mock_ctx.weight_prefetch_method = MagicMock()
patcher = patch( patcher = patch(
'vllm_ascend.ops.moe.experts_selector.get_forward_context', 'vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
return_value=self.mock_ctx) return_value=self.mock_ctx)
self.addCleanup(patcher.stop) self.addCleanup(patcher.stop)
patcher.start() patcher.start()
@@ -831,7 +831,7 @@ class TestSelectExperts(TestBase):
self.assertEqual(ids.shape, (self.num_tokens, self.top_k)) self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.dtype, torch.int32) self.assertEqual(ids.dtype, torch.int32)
@patch('vllm_ascend.ops.moe.experts_selector._native_grouped_topk') @patch('vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk')
def test_grouped_topk_with_correction_bias(self, mock_grouped_topk): def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
"""Test grouped topk with expert score correction bias""" """Test grouped topk with expert score correction bias"""
mock_grouped_topk.return_value = torch.ones(self.num_tokens, mock_grouped_topk.return_value = torch.ones(self.num_tokens,

View File

@@ -87,7 +87,8 @@ def set_ascend_forward_context(
): ):
forward_context = get_forward_context() forward_context = get_forward_context()
from vllm_ascend.ops.moe.moe_comm_method import get_moe_comm_method from vllm_ascend.ops.fused_moe.moe_comm_method import \
get_moe_comm_method
forward_context.moe_comm_type = moe_comm_type forward_context.moe_comm_type = moe_comm_type
forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type) forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)

View File

@@ -66,7 +66,7 @@ from vllm.platforms import current_platform
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.sfa import AscendSFAModules, Indexer from vllm_ascend.models.layers.sfa import AscendSFAModules, Indexer
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE from vllm_ascend.ops.fused_moe.fused_moe import AscendFusedMoE
from vllm_ascend.ops.linear import AscendLinearBase from vllm_ascend.ops.linear import AscendLinearBase
from vllm_ascend.utils import vllm_version_is from vllm_ascend.utils import vllm_version_is

View File

@@ -17,7 +17,7 @@
import torch import torch
import vllm_ascend.ops.common_fused_moe # noqa import vllm_ascend.ops.fused_moe.fused_moe # noqa
import vllm_ascend.ops.layernorm # noqa import vllm_ascend.ops.layernorm # noqa
import vllm_ascend.ops.register_custom_ops # noqa import vllm_ascend.ops.register_custom_ops # noqa
import vllm_ascend.ops.vocab_parallel_embedding # noqa import vllm_ascend.ops.vocab_parallel_embedding # noqa

View File

@@ -35,8 +35,8 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map, from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
determine_default_log2phy_map) determine_default_log2phy_map)
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.moe.experts_selector import select_experts from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, enable_sp, is_310p, from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, enable_sp, is_310p,
is_enable_nz, npu_stream_switch, is_enable_nz, npu_stream_switch,
shared_expert_dp_enabled, shared_expert_dp_enabled,

View File

@@ -24,15 +24,13 @@ from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import ( from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
FusedMoEPrepareAndFinalizeWithAll2All, from vllm_ascend.ops.fused_moe.prepare_finalize import (
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2, PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather,
FusedMoEPrepareAndFinalizeWithNaiveMulticast) PrepareAndFinalizeWithMC2, PrepareAndFinalizeWithNaiveMulticast)
from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp from vllm_ascend.ops.fused_moe.token_dispatcher import (
from vllm_ascend.ops.moe.token_dispatcher import (TokenDispatcherWithAll2AllV, TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
TokenDispatcherWithAllGather, TokenDispatcherWithMC2, TokenDispatcherWithMoge)
TokenDispatcherWithMC2,
TokenDispatcherWithMoge)
_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {} _MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
@@ -59,8 +57,7 @@ class MoECommMethod(ABC):
self.moe_config = moe_config self.moe_config = moe_config
self.token_dispatcher = self._get_token_dispatcher() self.token_dispatcher = self._get_token_dispatcher()
self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize( self.prepare_finalize = self._get_prepare_finalize()
)
def prepare( def prepare(
self, self,
@@ -71,7 +68,7 @@ class MoECommMethod(ABC):
gate=None gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]: Optional[torch.Tensor]]:
hidden_states, router_logits, mc2_mask, context_metadata = self.fused_moe_prepare_finalize.prepare( hidden_states, router_logits, mc2_mask, context_metadata = self.prepare_finalize.prepare(
hidden_states, router_logits, enable_shared_expert_dp, hidden_states, router_logits, enable_shared_expert_dp,
replace_allreduce, gate) replace_allreduce, gate)
return hidden_states, router_logits, mc2_mask, context_metadata return hidden_states, router_logits, mc2_mask, context_metadata
@@ -80,8 +77,9 @@ class MoECommMethod(ABC):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
reduce_results: bool, reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor: context_metadata: Optional[dict] = None) -> torch.Tensor:
hidden_states = self.fused_moe_prepare_finalize.finalize( hidden_states = self.prepare_finalize.finalize(hidden_states,
hidden_states, reduce_results, context_metadata) reduce_results,
context_metadata)
return hidden_states return hidden_states
def fused_experts( def fused_experts(
@@ -169,9 +167,9 @@ class MoECommMethod(ABC):
"_get_token_dispatcher function not implemented.") "_get_token_dispatcher function not implemented.")
@abstractmethod @abstractmethod
def _get_fused_moe_prepare_finalize(self): def _get_prepare_finalize(self):
raise NotImplementedError( raise NotImplementedError(
"_get_fused_moe_prepare_finalize function not implemented.") "_get_prepare_finalize function not implemented.")
class AllGatherCommImpl(MoECommMethod): class AllGatherCommImpl(MoECommMethod):
@@ -205,8 +203,8 @@ class AllGatherCommImpl(MoECommMethod):
num_experts=self.moe_config.num_experts, num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts) num_local_experts=self.moe_config.num_local_experts)
def _get_fused_moe_prepare_finalize(self): def _get_prepare_finalize(self):
return FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config) return PrepareAndFinalizeWithAllGather(self.moe_config)
class MC2CommImpl(MoECommMethod): class MC2CommImpl(MoECommMethod):
@@ -222,8 +220,8 @@ class MC2CommImpl(MoECommMethod):
def _get_token_dispatcher(self): def _get_token_dispatcher(self):
return TokenDispatcherWithMC2() return TokenDispatcherWithMC2()
def _get_fused_moe_prepare_finalize(self): def _get_prepare_finalize(self):
return FusedMoEPrepareAndFinalizeWithMC2(self.moe_config) return PrepareAndFinalizeWithMC2(self.moe_config)
class AlltoAllCommImpl(MoECommMethod): class AlltoAllCommImpl(MoECommMethod):
@@ -242,8 +240,8 @@ class AlltoAllCommImpl(MoECommMethod):
num_experts=self.moe_config.num_experts, num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts) num_local_experts=self.moe_config.num_local_experts)
def _get_fused_moe_prepare_finalize(self): def _get_prepare_finalize(self):
return FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config) return PrepareAndFinalizeWithAll2All(self.moe_config)
class NaiveMulticastCommImpl(MoECommMethod): class NaiveMulticastCommImpl(MoECommMethod):
@@ -271,5 +269,5 @@ class NaiveMulticastCommImpl(MoECommMethod):
num_experts=self.moe_config.num_experts, num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts) num_local_experts=self.moe_config.num_local_experts)
def _get_fused_moe_prepare_finalize(self): def _get_prepare_finalize(self):
return FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config) return PrepareAndFinalizeWithNaiveMulticast(self.moe_config)

View File

@@ -30,7 +30,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.utils import enable_sp, get_rm_router_logits_state from vllm_ascend.utils import enable_sp, get_rm_router_logits_state
class FusedMoEPrepareAndFinalize(ABC): class PrepareAndFinalize(ABC):
""" """
Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization
in distributed environments. Subclasses implement specific communication strategies in distributed environments. Subclasses implement specific communication strategies
@@ -103,7 +103,7 @@ class FusedMoEPrepareAndFinalize(ABC):
raise NotImplementedError("Finalize function not implemented.") raise NotImplementedError("Finalize function not implemented.")
class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize): class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
""" """
MoE communication strategy using All-to-All style slicing. MoE communication strategy using All-to-All style slicing.
Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing. Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing.
@@ -195,7 +195,7 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
return hidden_states return hidden_states
class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All): class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
""" """
MoE communication strategy using MC2, which is based on All2All. Hence, it inherits MoE communication strategy using MC2, which is based on All2All. Hence, it inherits
All2All and share the same finalize method. All2All and share the same finalize method.
@@ -275,7 +275,7 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
return hidden_states, router_logits, mc2_mask, context_metadata return hidden_states, router_logits, mc2_mask, context_metadata
class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
""" """
MoE communication strategy using All-Gather + Reduce-Scatter on EP group. MoE communication strategy using All-Gather + Reduce-Scatter on EP group.
There are two sets of prepare and finalize: There are two sets of prepare and finalize:
@@ -429,7 +429,7 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
return hidden_states return hidden_states
class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize): class PrepareAndFinalizeWithNaiveMulticast(PrepareAndFinalize):
""" """
MoE communication strategy using Naive Multicast (point-to-point broadcast). MoE communication strategy using Naive Multicast (point-to-point broadcast).
Will be used in prefill when using allgather in decode. Each DP rank broadcasts its slice to all others. Will be used in prefill when using allgather in decode. Each DP rank broadcasts its slice to all others.

View File

@@ -28,7 +28,7 @@ import torch_npu
from vllm.distributed.parallel_state import get_ep_group from vllm.distributed.parallel_state import get_ep_group
from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.moe.comm_utils import ( from vllm_ascend.ops.fused_moe.comm_utils import (
async_all_to_all, gather_from_sequence_parallel_region) async_all_to_all, gather_from_sequence_parallel_region)
from vllm_ascend.utils import (AscendSocVersion, get_ascend_soc_version, from vllm_ascend.utils import (AscendSocVersion, get_ascend_soc_version,
is_hierarchical_communication_enabled) is_hierarchical_communication_enabled)

View File

@@ -37,7 +37,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group, from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
get_otp_group) get_otp_group)
from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable, from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
oproj_tp_enable) oproj_tp_enable)

View File

@@ -26,7 +26,7 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.moe.experts_selector import select_experts from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz

View File

@@ -24,7 +24,7 @@ from vllm.distributed.parallel_state import get_ep_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.moe.experts_selector import select_experts from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz

View File

@@ -25,7 +25,7 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.moe.experts_selector import select_experts from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_enable_nz, from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_enable_nz,
vllm_version_is) vllm_version_is)

View File

@@ -538,8 +538,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
from vllm_ascend.models.layers.sfa import AscendSparseFlashAttention from vllm_ascend.models.layers.sfa import AscendSparseFlashAttention
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE, from vllm_ascend.ops.fused_moe.fused_moe import (AscendFusedMoE,
AscendSharedFusedMoE) AscendSharedFusedMoE)
from vllm_ascend.ops.layernorm import (AscendGemmaRMSNorm, from vllm_ascend.ops.layernorm import (AscendGemmaRMSNorm,
AscendQuantRMSNorm, AscendRMSNorm) AscendQuantRMSNorm, AscendRMSNorm)
from vllm_ascend.ops.linear import (AscendColumnParallelLinear, from vllm_ascend.ops.linear import (AscendColumnParallelLinear,