[Refactor] [MoE] Rename moe-related classes & files (#3646)
### What this PR does / why we need it?
1. Rename common_fused_moe.py to fused_moe.py.
2. Rename fused_moe_prepare_and_finalize.py / FusedMoEPrepareAndFinalize
to prepare_finalize.py / PrepareAndFinalize.
3. Rename vllm_ascend/ops/moe to vllm_ascend/ops/fused_moe.
4. Move vllm_ascend/ops/fused_moe.py to
vllm_ascend/ops/fused_moe/fused_moe.py
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e & ut
- vLLM version: v0.11.0rc3
- vLLM main:
17c540a993
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
@@ -20,7 +20,7 @@ import torch
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.ops.moe.comm_utils import (
|
||||
from vllm_ascend.ops.fused_moe.comm_utils import (
|
||||
_gather_along_first_dim, async_all_to_all,
|
||||
gather_from_sequence_parallel_region)
|
||||
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
|
||||
|
||||
|
||||
class TestLoadWeight(TestBase):
|
||||
|
||||
def test_load_w13_transpose(self):
|
||||
with patch.object(AscendFusedMoE, "__init__",
|
||||
lambda self, *args, **kwargs: None):
|
||||
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
|
||||
|
||||
expert_data = torch.randn(128, 8)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
|
||||
|
||||
expert_data = torch.randn(8, 128)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
|
||||
|
||||
expert_data = torch.randn(128, 8)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
|
||||
|
||||
expert_data = torch.randn(8, 128)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
|
||||
|
||||
def test_load_w2_transpose(self):
|
||||
with patch.object(AscendFusedMoE, "__init__",
|
||||
lambda self, *args, **kwargs: None):
|
||||
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
|
||||
expert_data = torch.randn(128, 4)
|
||||
loaded_weight = torch.randn(128, 8)
|
||||
moe._load_w2(expert_data, 1, loaded_weight, 0)
|
||||
|
||||
expert_data = torch.randn(4, 128)
|
||||
loaded_weight = torch.randn(128, 8)
|
||||
moe._load_w2(expert_data, 1, loaded_weight, 0)
|
||||
@@ -24,9 +24,11 @@ from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.fused_moe.fused_moe import (
|
||||
AscendFusedMoE, AscendUnquantizedFusedMoEMethod)
|
||||
from vllm_ascend.ops.fused_moe.moe_mlp import (cumsum_group_list,
|
||||
unified_apply_mlp)
|
||||
from vllm_ascend.utils import AscendSocVersion, adapt_patch
|
||||
|
||||
adapt_patch(True)
|
||||
@@ -69,10 +71,11 @@ def setup_vllm_config_mock(mocker: MockerFixture):
|
||||
mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
|
||||
mock_vllm_config.model_config.max_model_len = 2048
|
||||
|
||||
mocker.patch('vllm_ascend.ops.common_fused_moe.get_current_vllm_config',
|
||||
return_value=mock_vllm_config)
|
||||
mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
|
||||
mocker.patch('vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config',
|
||||
return_value=mock_vllm_config)
|
||||
mocker.patch(
|
||||
'vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config',
|
||||
return_value=mock_vllm_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -105,37 +108,37 @@ def mock_dist_env(mocker: MockerFixture):
|
||||
|
||||
with patch('torch.distributed.get_rank', return_value=0), \
|
||||
patch('torch.distributed.get_world_size', return_value=4), \
|
||||
patch('vllm_ascend.ops.common_fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.common_fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.common_fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.common_fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
|
||||
return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.common_fused_moe.get_ascend_config',
|
||||
patch('vllm_ascend.ops.fused_moe.fused_moe.get_ascend_config',
|
||||
return_value=MagicMock(
|
||||
torchair_graph_config=MagicMock(enabled=False),
|
||||
enable_multistream_moe=False,
|
||||
expert_map_path=None
|
||||
)), \
|
||||
patch('vllm_ascend.ops.common_fused_moe.determine_expert_map',
|
||||
patch('vllm_ascend.ops.fused_moe.fused_moe.determine_expert_map',
|
||||
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
|
||||
patch('vllm_ascend.ops.common_fused_moe.get_forward_context',
|
||||
patch('vllm_ascend.ops.fused_moe.fused_moe.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
|
||||
patch('vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
|
||||
patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
|
||||
patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch('vllm_ascend.ops.moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
|
||||
patch('vllm_ascend.ops.fused_moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
|
||||
return_value=None), \
|
||||
patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
|
||||
patch('vllm_ascend.ops.fused_moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
|
||||
return_value=None), \
|
||||
patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
|
||||
patch('vllm_ascend.ops.fused_moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
|
||||
return_value=None), \
|
||||
patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
|
||||
patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
|
||||
return_value=mock_forward_context_obj):
|
||||
|
||||
yield {
|
||||
@@ -319,8 +322,8 @@ class TestCumsumGroupList(TestBase):
|
||||
|
||||
class TestUnifiedApplyMLP(TestBase):
|
||||
|
||||
@patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context')
|
||||
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
|
||||
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
|
||||
@patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
@patch('torch_npu.npu_dequant_swiglu_quant')
|
||||
@@ -384,7 +387,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
|
||||
@patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
@@ -426,7 +429,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.dtype, torch.float16)
|
||||
|
||||
@patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context')
|
||||
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
@@ -486,7 +489,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
|
||||
@patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
@@ -531,7 +534,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.dtype, torch.float16)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context")
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context")
|
||||
@patch("torch_npu.npu_grouped_matmul")
|
||||
@patch("torch_npu.npu_swiglu")
|
||||
@patch("torch_npu.npu_grouped_matmul_swiglu_quant")
|
||||
@@ -595,3 +598,39 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
self.assertTrue(mock_forward_context.with_quant)
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
|
||||
class TestLoadWeight(TestBase):
|
||||
|
||||
def test_load_w13_transpose(self):
|
||||
with patch.object(AscendFusedMoE, "__init__",
|
||||
lambda self, *args, **kwargs: None):
|
||||
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
|
||||
|
||||
expert_data = torch.randn(128, 8)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
|
||||
|
||||
expert_data = torch.randn(8, 128)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
|
||||
|
||||
expert_data = torch.randn(128, 8)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
|
||||
|
||||
expert_data = torch.randn(8, 128)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
|
||||
|
||||
def test_load_w2_transpose(self):
|
||||
with patch.object(AscendFusedMoE, "__init__",
|
||||
lambda self, *args, **kwargs: None):
|
||||
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
|
||||
expert_data = torch.randn(128, 4)
|
||||
loaded_weight = torch.randn(128, 8)
|
||||
moe._load_w2(expert_data, 1, loaded_weight, 0)
|
||||
|
||||
expert_data = torch.randn(4, 128)
|
||||
loaded_weight = torch.randn(128, 8)
|
||||
moe._load_w2(expert_data, 1, loaded_weight, 0)
|
||||
@@ -4,8 +4,9 @@ import torch
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
|
||||
AlltoAllCommImpl, MC2CommImpl)
|
||||
from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
|
||||
AlltoAllCommImpl,
|
||||
MC2CommImpl)
|
||||
|
||||
|
||||
class TestMoECommMethod(TestBase):
|
||||
@@ -24,12 +25,14 @@ class TestMoECommMethod(TestBase):
|
||||
self.moe_config.dp_group = MagicMock()
|
||||
self.moe_config.num_global_redundant_experts = 0
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
|
||||
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
|
||||
)
|
||||
@patch(
|
||||
"vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
|
||||
)
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
|
||||
def test_all_gather_comm_impl(self, mock_token_dispatcher,
|
||||
mock_prepare_finalize,
|
||||
mock_get_forward_context,
|
||||
@@ -72,12 +75,11 @@ class TestMoECommMethod(TestBase):
|
||||
context_metadata=context_metadata)
|
||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithMC2"
|
||||
)
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithMC2")
|
||||
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithMC2")
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithMC2")
|
||||
def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
|
||||
mock_get_forward_context,
|
||||
mock_get_current_vllm_config):
|
||||
@@ -121,12 +123,14 @@ class TestMoECommMethod(TestBase):
|
||||
context_metadata=context_metadata)
|
||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAll2All"
|
||||
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAll2All"
|
||||
)
|
||||
@patch(
|
||||
"vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAll2AllV"
|
||||
)
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAll2AllV")
|
||||
def test_alltoall_comm_impl(self, mock_token_dispatcher,
|
||||
mock_prepare_finalize,
|
||||
mock_get_forward_context,
|
||||
@@ -163,13 +167,15 @@ class TestMoECommMethod(TestBase):
|
||||
mock_pf_instance.prepare.assert_called_once_with(
|
||||
hidden_states, router_logits, False, False, None)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
|
||||
"vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
|
||||
)
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.unified_apply_mlp")
|
||||
@patch(
|
||||
"vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
|
||||
)
|
||||
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.unified_apply_mlp")
|
||||
def test_fused_experts_method(self, mock_unified_apply_mlp,
|
||||
mock_token_dispatcher, mock_prepare_finalize,
|
||||
mock_get_forward_context,
|
||||
|
||||
@@ -4,13 +4,12 @@ from unittest.mock import MagicMock, patch
|
||||
import torch
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
|
||||
FusedMoEPrepareAndFinalizeWithAll2All,
|
||||
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
|
||||
FusedMoEPrepareAndFinalizeWithNaiveMulticast)
|
||||
from vllm_ascend.ops.fused_moe.prepare_finalize import (
|
||||
PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather,
|
||||
PrepareAndFinalizeWithMC2, PrepareAndFinalizeWithNaiveMulticast)
|
||||
|
||||
|
||||
class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
class TestPrepareAndFinalize(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Mock FusedMoEConfig
|
||||
@@ -24,14 +23,12 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
self.moe_config.original_num_experts = 8
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
|
||||
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
|
||||
return_value=1)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
|
||||
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
|
||||
return_value=0)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
|
||||
)
|
||||
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
|
||||
def test_mc2_prepare_finalize(self, mock_get_forward_context, mock_tp_rank,
|
||||
mock_tp_size):
|
||||
mock_context = MagicMock()
|
||||
@@ -39,7 +36,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
mock_context.padded_num_tokens = 4
|
||||
mock_get_forward_context.return_value = mock_context
|
||||
|
||||
layer = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
|
||||
layer = PrepareAndFinalizeWithMC2(self.moe_config)
|
||||
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
@@ -59,14 +56,12 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
self.assertEqual(result.shape[0], 3)
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
|
||||
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
|
||||
return_value=2)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
|
||||
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
|
||||
return_value=0)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
|
||||
)
|
||||
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
|
||||
@patch("torch.distributed.all_gather")
|
||||
def test_mc2_tp_split_allgather(self, mock_all_gather,
|
||||
mock_get_forward_context, mock_tp_rank,
|
||||
@@ -76,7 +71,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
mock_context.padded_num_tokens = 4
|
||||
mock_get_forward_context.return_value = mock_context
|
||||
|
||||
layer = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
|
||||
layer = PrepareAndFinalizeWithMC2(self.moe_config)
|
||||
hidden_states = torch.randn(4, 8)
|
||||
router_logits = torch.randn(4, 2)
|
||||
|
||||
@@ -108,13 +103,13 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
self.assertEqual(final_result.shape[0], 4)
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
|
||||
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
|
||||
return_value=1)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
|
||||
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
|
||||
return_value=0)
|
||||
def test_all2all_prepare_finalize(self, mock_tp_rank, mock_tp_size):
|
||||
layer = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)
|
||||
layer = PrepareAndFinalizeWithAll2All(self.moe_config)
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
|
||||
@@ -130,15 +125,15 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
self.assertEqual(result.shape[0], 3)
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
|
||||
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
|
||||
return_value=2)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
|
||||
"vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
|
||||
return_value=0)
|
||||
@patch("torch.distributed.all_gather")
|
||||
def test_all2all_tp_split_allgather(self, mock_all_gather, mock_tp_rank,
|
||||
mock_tp_size):
|
||||
layer = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)
|
||||
layer = PrepareAndFinalizeWithAll2All(self.moe_config)
|
||||
hidden_states = torch.randn(2, 8)
|
||||
router_logits = torch.randn(2, 2)
|
||||
|
||||
@@ -169,14 +164,15 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
# Should concat back
|
||||
self.assertEqual(final_result.shape[0], 2)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group")
|
||||
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_dp_group")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce"
|
||||
"vllm_ascend.ops.fused_moe.prepare_finalize.tensor_model_parallel_all_reduce"
|
||||
)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
|
||||
)
|
||||
def test_allgather_prepare_finalize(self, mock_get_forward_context,
|
||||
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
|
||||
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.enable_sp",
|
||||
return_value=False)
|
||||
def test_allgather_prepare_finalize(self, mock_enable_sp,
|
||||
mock_get_forward_context,
|
||||
mock_tp_all_reduce, mock_get_dp_group):
|
||||
# Mock forward context
|
||||
mock_context = MagicMock()
|
||||
@@ -198,7 +194,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
self.moe_config.ep_size = 1
|
||||
self.moe_config.dp_group = mock_dp_group
|
||||
|
||||
layer = FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config)
|
||||
layer = PrepareAndFinalizeWithAllGather(self.moe_config)
|
||||
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
@@ -232,13 +228,11 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
result_with_tp = layer.finalize(h_out, reduce_results=True)
|
||||
self.assertEqual(result_with_tp.shape[0], 3)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group")
|
||||
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_dp_group")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce"
|
||||
)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
|
||||
"vllm_ascend.ops.fused_moe.prepare_finalize.tensor_model_parallel_all_reduce"
|
||||
)
|
||||
@patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
|
||||
def test_naive_multicast_prepare_finalize(self, mock_get_forward_context,
|
||||
mock_tp_all_reduce,
|
||||
mock_get_dp_group):
|
||||
@@ -266,7 +260,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
self.moe_config.tp_size = 1
|
||||
self.moe_config.ep_size = 1
|
||||
|
||||
layer = FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config)
|
||||
layer = PrepareAndFinalizeWithNaiveMulticast(self.moe_config)
|
||||
|
||||
# Local inputs
|
||||
hidden_states = torch.randn(3, 8)
|
||||
@@ -21,7 +21,7 @@ import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
|
||||
from vllm_ascend.ops.moe.token_dispatcher import ( # isort: skip
|
||||
from vllm_ascend.ops.fused_moe.token_dispatcher import ( # isort: skip
|
||||
AscendSocVersion, TokenDispatcherWithAll2AllV,
|
||||
TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
|
||||
|
||||
@@ -34,7 +34,7 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
self.mc2_group.rank_in_group = 0
|
||||
self.mc2_group.world_size = 8
|
||||
self.mc2_group_patch = patch(
|
||||
"vllm_ascend.ops.moe.token_dispatcher.get_mc2_group",
|
||||
"vllm_ascend.ops.fused_moe.token_dispatcher.get_mc2_group",
|
||||
return_value=self.mc2_group)
|
||||
self.mc2_group_patch.start()
|
||||
|
||||
@@ -52,7 +52,7 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
|
||||
# Mock get_ascend_soc_version()
|
||||
self.ascend_soc_version_patch = patch(
|
||||
"vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version",
|
||||
"vllm_ascend.ops.fused_moe.token_dispatcher.get_ascend_soc_version",
|
||||
return_value=AscendSocVersion.A3)
|
||||
self.ascend_soc_version_patch.start()
|
||||
|
||||
@@ -369,7 +369,8 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16)
|
||||
|
||||
# Mock async_all_to_all
|
||||
patcher6 = patch('vllm_ascend.ops.moe.comm_utils.async_all_to_all')
|
||||
patcher6 = patch(
|
||||
'vllm_ascend.ops.fused_moe.comm_utils.async_all_to_all')
|
||||
self.mock_async_all_to_all = patcher6.start()
|
||||
self.addCleanup(patcher6.stop)
|
||||
self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16),
|
||||
@@ -377,7 +378,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
|
||||
# Mock gather_from_sequence_parallel_region
|
||||
patcher7 = patch(
|
||||
'vllm_ascend.ops.moe.token_dispatcher.gather_from_sequence_parallel_region'
|
||||
'vllm_ascend.ops.fused_moe.token_dispatcher.gather_from_sequence_parallel_region'
|
||||
)
|
||||
self.mock_gather_from_sequence_parallel_region = patcher7.start()
|
||||
self.addCleanup(patcher7.stop)
|
||||
|
||||
Reference in New Issue
Block a user