[Refactor] [MoE] Rename moe-related classes & files (#3646)

### What this PR does / why we need it?
1. Rename common_fused_moe.py to fused_moe.py.
2. Rename fused_moe_prepare_and_finalize.py / FusedMoEPrepareAndFinalize
to prepare_finalize.py / PrepareAndFinalize.
3. Rename vllm_ascend/ops/moe to vllm_ascend/ops/fused_moe.
4. Move vllm_ascend/ops/fused_moe.py to
vllm_ascend/ops/fused_moe/fused_moe.py
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e & ut

- vLLM version: v0.11.0rc3
- vLLM main:
17c540a993

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
weichen
2025-10-25 11:22:03 +08:00
committed by GitHub
parent 0637e8f021
commit 63c363d3de
25 changed files with 183 additions and 199 deletions

View File

@@ -20,7 +20,7 @@ import torch
from pytest_mock import MockerFixture
from tests.ut.base import PytestBase
from vllm_ascend.ops.moe.comm_utils import (
from vllm_ascend.ops.fused_moe.comm_utils import (
_gather_along_first_dim, async_all_to_all,
gather_from_sequence_parallel_region)

View File

@@ -1,56 +0,0 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from unittest.mock import patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
class TestLoadWeight(TestBase):
def test_load_w13_transpose(self):
with patch.object(AscendFusedMoE, "__init__",
lambda self, *args, **kwargs: None):
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
expert_data = torch.randn(128, 8)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
expert_data = torch.randn(8, 128)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
expert_data = torch.randn(128, 8)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
expert_data = torch.randn(8, 128)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
def test_load_w2_transpose(self):
with patch.object(AscendFusedMoE, "__init__",
lambda self, *args, **kwargs: None):
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
expert_data = torch.randn(128, 4)
loaded_weight = torch.randn(128, 8)
moe._load_w2(expert_data, 1, loaded_weight, 0)
expert_data = torch.randn(4, 128)
loaded_weight = torch.randn(128, 8)
moe._load_w2(expert_data, 1, loaded_weight, 0)

View File

@@ -24,9 +24,11 @@ from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.fused_moe.fused_moe import (
AscendFusedMoE, AscendUnquantizedFusedMoEMethod)
from vllm_ascend.ops.fused_moe.moe_mlp import (cumsum_group_list,
unified_apply_mlp)
from vllm_ascend.utils import AscendSocVersion, adapt_patch
adapt_patch(True)
@@ -69,10 +71,11 @@ def setup_vllm_config_mock(mocker: MockerFixture):
mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
mock_vllm_config.model_config.max_model_len = 2048
mocker.patch('vllm_ascend.ops.common_fused_moe.get_current_vllm_config',
return_value=mock_vllm_config)
mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
mocker.patch('vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config',
return_value=mock_vllm_config)
mocker.patch(
'vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config',
return_value=mock_vllm_config)
@pytest.fixture
@@ -105,37 +108,37 @@ def mock_dist_env(mocker: MockerFixture):
with patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \
patch('vllm_ascend.ops.common_fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.common_fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.common_fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.common_fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.common_fused_moe.get_ascend_config',
patch('vllm_ascend.ops.fused_moe.fused_moe.get_ascend_config',
return_value=MagicMock(
torchair_graph_config=MagicMock(enabled=False),
enable_multistream_moe=False,
expert_map_path=None
)), \
patch('vllm_ascend.ops.common_fused_moe.determine_expert_map',
patch('vllm_ascend.ops.fused_moe.fused_moe.determine_expert_map',
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
patch('vllm_ascend.ops.common_fused_moe.get_forward_context',
patch('vllm_ascend.ops.fused_moe.fused_moe.get_forward_context',
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
patch('vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context',
return_value=mock_forward_context_obj), \
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context',
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
patch('vllm_ascend.ops.fused_moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
return_value=None), \
patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
patch('vllm_ascend.ops.fused_moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
return_value=None), \
patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
patch('vllm_ascend.ops.fused_moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
return_value=None), \
patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
return_value=mock_forward_context_obj):
yield {
@@ -319,8 +322,8 @@ class TestCumsumGroupList(TestBase):
class TestUnifiedApplyMLP(TestBase):
@patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context')
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
@patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_dynamic_quant')
@patch('torch_npu.npu_dequant_swiglu_quant')
@@ -384,7 +387,7 @@ class TestUnifiedApplyMLP(TestBase):
self.assertEqual(result.dtype, torch.bfloat16)
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
@patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
@@ -426,7 +429,7 @@ class TestUnifiedApplyMLP(TestBase):
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.float16)
@patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context')
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
@@ -486,7 +489,7 @@ class TestUnifiedApplyMLP(TestBase):
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.bfloat16)
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
@patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
@@ -531,7 +534,7 @@ class TestUnifiedApplyMLP(TestBase):
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.float16)
@patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context")
@patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context")
@patch("torch_npu.npu_grouped_matmul")
@patch("torch_npu.npu_swiglu")
@patch("torch_npu.npu_grouped_matmul_swiglu_quant")
@@ -595,3 +598,39 @@ class TestUnifiedApplyMLP(TestBase):
self.assertTrue(mock_forward_context.with_quant)
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.bfloat16)
class TestLoadWeight(TestBase):
def test_load_w13_transpose(self):
with patch.object(AscendFusedMoE, "__init__",
lambda self, *args, **kwargs: None):
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
expert_data = torch.randn(128, 8)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
expert_data = torch.randn(8, 128)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
expert_data = torch.randn(128, 8)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
expert_data = torch.randn(8, 128)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
def test_load_w2_transpose(self):
with patch.object(AscendFusedMoE, "__init__",
lambda self, *args, **kwargs: None):
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
expert_data = torch.randn(128, 4)
loaded_weight = torch.randn(128, 8)
moe._load_w2(expert_data, 1, loaded_weight, 0)
expert_data = torch.randn(4, 128)
loaded_weight = torch.randn(128, 8)
moe._load_w2(expert_data, 1, loaded_weight, 0)

View File

@@ -4,8 +4,9 @@ import torch
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from tests.ut.base import TestBase
from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
AlltoAllCommImpl, MC2CommImpl)
from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
AlltoAllCommImpl,
MC2CommImpl)
class TestMoECommMethod(TestBase):
@@ -24,12 +25,14 @@ class TestMoECommMethod(TestBase):
self.moe_config.dp_group = MagicMock()
self.moe_config.num_global_redundant_experts = 0
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
)
@patch(
"vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
)
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
def test_all_gather_comm_impl(self, mock_token_dispatcher,
mock_prepare_finalize,
mock_get_forward_context,
@@ -72,12 +75,11 @@ class TestMoECommMethod(TestBase):
context_metadata=context_metadata)
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithMC2"
)
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithMC2")
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithMC2")
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithMC2")
def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
mock_get_forward_context,
mock_get_current_vllm_config):
@@ -121,12 +123,14 @@ class TestMoECommMethod(TestBase):
context_metadata=context_metadata)
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAll2All"
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAll2All"
)
@patch(
"vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAll2AllV"
)
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAll2AllV")
def test_alltoall_comm_impl(self, mock_token_dispatcher,
mock_prepare_finalize,
mock_get_forward_context,
@@ -163,13 +167,15 @@ class TestMoECommMethod(TestBase):
mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, None)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
)
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
@patch("vllm_ascend.ops.moe.moe_comm_method.unified_apply_mlp")
@patch(
"vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
)
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.unified_apply_mlp")
def test_fused_experts_method(self, mock_unified_apply_mlp,
mock_token_dispatcher, mock_prepare_finalize,
mock_get_forward_context,

View File

@@ -4,13 +4,12 @@ from unittest.mock import MagicMock, patch
import torch
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
FusedMoEPrepareAndFinalizeWithAll2All,
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
FusedMoEPrepareAndFinalizeWithNaiveMulticast)
from vllm_ascend.ops.fused_moe.prepare_finalize import (
PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather,
PrepareAndFinalizeWithMC2, PrepareAndFinalizeWithNaiveMulticast)
class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
class TestPrepareAndFinalize(unittest.TestCase):
def setUp(self):
# Mock FusedMoEConfig
@@ -24,14 +23,12 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
self.moe_config.original_num_experts = 8
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
return_value=1)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
return_value=0)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
)
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
def test_mc2_prepare_finalize(self, mock_get_forward_context, mock_tp_rank,
mock_tp_size):
mock_context = MagicMock()
@@ -39,7 +36,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
mock_context.padded_num_tokens = 4
mock_get_forward_context.return_value = mock_context
layer = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
layer = PrepareAndFinalizeWithMC2(self.moe_config)
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)
@@ -59,14 +56,12 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
self.assertEqual(result.shape[0], 3)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
return_value=2)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
return_value=0)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
)
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
@patch("torch.distributed.all_gather")
def test_mc2_tp_split_allgather(self, mock_all_gather,
mock_get_forward_context, mock_tp_rank,
@@ -76,7 +71,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
mock_context.padded_num_tokens = 4
mock_get_forward_context.return_value = mock_context
layer = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
layer = PrepareAndFinalizeWithMC2(self.moe_config)
hidden_states = torch.randn(4, 8)
router_logits = torch.randn(4, 2)
@@ -108,13 +103,13 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
self.assertEqual(final_result.shape[0], 4)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
return_value=1)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
return_value=0)
def test_all2all_prepare_finalize(self, mock_tp_rank, mock_tp_size):
layer = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)
layer = PrepareAndFinalizeWithAll2All(self.moe_config)
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)
@@ -130,15 +125,15 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
self.assertEqual(result.shape[0], 3)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
return_value=2)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
return_value=0)
@patch("torch.distributed.all_gather")
def test_all2all_tp_split_allgather(self, mock_all_gather, mock_tp_rank,
mock_tp_size):
layer = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)
layer = PrepareAndFinalizeWithAll2All(self.moe_config)
hidden_states = torch.randn(2, 8)
router_logits = torch.randn(2, 2)
@@ -169,14 +164,15 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
# Should concat back
self.assertEqual(final_result.shape[0], 2)
@patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group")
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_dp_group")
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce"
"vllm_ascend.ops.fused_moe.prepare_finalize.tensor_model_parallel_all_reduce"
)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
)
def test_allgather_prepare_finalize(self, mock_get_forward_context,
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.enable_sp",
return_value=False)
def test_allgather_prepare_finalize(self, mock_enable_sp,
mock_get_forward_context,
mock_tp_all_reduce, mock_get_dp_group):
# Mock forward context
mock_context = MagicMock()
@@ -198,7 +194,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
self.moe_config.ep_size = 1
self.moe_config.dp_group = mock_dp_group
layer = FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config)
layer = PrepareAndFinalizeWithAllGather(self.moe_config)
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)
@@ -232,13 +228,11 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
result_with_tp = layer.finalize(h_out, reduce_results=True)
self.assertEqual(result_with_tp.shape[0], 3)
@patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group")
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_dp_group")
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce"
)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
"vllm_ascend.ops.fused_moe.prepare_finalize.tensor_model_parallel_all_reduce"
)
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
def test_naive_multicast_prepare_finalize(self, mock_get_forward_context,
mock_tp_all_reduce,
mock_get_dp_group):
@@ -266,7 +260,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
self.moe_config.tp_size = 1
self.moe_config.ep_size = 1
layer = FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config)
layer = PrepareAndFinalizeWithNaiveMulticast(self.moe_config)
# Local inputs
hidden_states = torch.randn(3, 8)

View File

@@ -21,7 +21,7 @@ import torch
from tests.ut.base import TestBase
from vllm_ascend.ops.moe.token_dispatcher import ( # isort: skip
from vllm_ascend.ops.fused_moe.token_dispatcher import ( # isort: skip
AscendSocVersion, TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
@@ -34,7 +34,7 @@ class TestTokenDispatcherWithMC2(TestBase):
self.mc2_group.rank_in_group = 0
self.mc2_group.world_size = 8
self.mc2_group_patch = patch(
"vllm_ascend.ops.moe.token_dispatcher.get_mc2_group",
"vllm_ascend.ops.fused_moe.token_dispatcher.get_mc2_group",
return_value=self.mc2_group)
self.mc2_group_patch.start()
@@ -52,7 +52,7 @@ class TestTokenDispatcherWithMC2(TestBase):
# Mock get_ascend_soc_version()
self.ascend_soc_version_patch = patch(
"vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version",
"vllm_ascend.ops.fused_moe.token_dispatcher.get_ascend_soc_version",
return_value=AscendSocVersion.A3)
self.ascend_soc_version_patch.start()
@@ -369,7 +369,8 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16)
# Mock async_all_to_all
patcher6 = patch('vllm_ascend.ops.moe.comm_utils.async_all_to_all')
patcher6 = patch(
'vllm_ascend.ops.fused_moe.comm_utils.async_all_to_all')
self.mock_async_all_to_all = patcher6.start()
self.addCleanup(patcher6.stop)
self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16),
@@ -377,7 +378,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
# Mock gather_from_sequence_parallel_region
patcher7 = patch(
'vllm_ascend.ops.moe.token_dispatcher.gather_from_sequence_parallel_region'
'vllm_ascend.ops.fused_moe.token_dispatcher.gather_from_sequence_parallel_region'
)
self.mock_gather_from_sequence_parallel_region = patcher7.start()
self.addCleanup(patcher7.stop)