[Refactor] [MoE] Rename moe-related classes & files (#3646)
### What this PR does / why we need it?
1. Rename common_fused_moe.py to fused_moe.py.
2. Rename fused_moe_prepare_and_finalize.py / FusedMoEPrepareAndFinalize to prepare_finalize.py / PrepareAndFinalize.
3. Rename vllm_ascend/ops/moe to vllm_ascend/ops/fused_moe.
4. Move vllm_ascend/ops/fused_moe.py to vllm_ascend/ops/fused_moe/fused_moe.py.
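
Since this is a pure rename, only internal import paths change. A minimal before/after sketch (module and class names taken from the diff below):

```python
# Old layout: helper modules under vllm_ascend.ops.moe, and the
# prepare/finalize classes carrying a FusedMoE prefix (pre-PR imports,
# shown here only for comparison).
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
    FusedMoEPrepareAndFinalizeWithMC2)

# New layout: the package is vllm_ascend.ops.fused_moe, the FusedMoE
# prefix is dropped, and AscendFusedMoE now lives in
# vllm_ascend/ops/fused_moe/fused_moe.py.
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.fused_moe.prepare_finalize import (
    PrepareAndFinalizeWithMC2)
from vllm_ascend.ops.fused_moe.fused_moe import AscendFusedMoE
```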
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e and unit tests.
- vLLM version: v0.11.0rc3
- vLLM main: 17c540a993
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
@@ -28,9 +28,10 @@ import torch
 import torch_npu
 from vllm.model_executor.layers.activation import SiluAndMul
 
-from vllm_ascend.ops.moe.experts_selector import select_experts
-from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
-from vllm_ascend.ops.moe.token_dispatcher import TokenDispatcherWithAllGather
+from vllm_ascend.ops.fused_moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
+from vllm_ascend.ops.fused_moe.token_dispatcher import \
+    TokenDispatcherWithAllGather
 
 NUM_EXPERTS = [8, 64]
 EP_SIZE = [1]
@@ -182,7 +183,7 @@ def test_token_dispatcher_with_all_gather_quant(
 ):
     context_mock = MagicMock()
     context_mock.fused_moe_state = 0
-    with patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context",
+    with patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context",
                return_value=context_mock):
         a = torch.randn((m, k), device=device, dtype=dtype) / 10
         w1 = torch.randn((e, k, 2 * n), device=device, dtype=torch.int8)
@@ -282,9 +283,9 @@ def test_select_experts(
                               dtype=torch.int32)
     custom_routing_function.return_value = (mock_weights, mock_ids)
 
-    with patch("vllm_ascend.ops.moe.experts_selector._native_grouped_topk"
+    with patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk"
                ) as mock_native_grouped_topk, \
-         patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
+         patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
                return_value=MagicMock(weight_prefetch_method=MagicMock())):
         mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
             x)
@@ -318,7 +319,7 @@ def test_select_experts(
 
 @pytest.mark.parametrize("device", DEVICE)
 def test_select_experts_invalid_scoring_func(device: str):
-    with patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
+    with patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
                return_value=MagicMock(weight_prefetch_method=MagicMock())), \
          pytest.raises(ValueError,
                        match="Unsupported scoring function: invalid"):
@@ -90,9 +90,9 @@ def mock_distributed():
     mock_vllm_config.scheduler_config = Mock(max_num_seqs=256)
     mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None)
 
-    with patch("vllm_ascend.ops.common_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
-         patch("vllm_ascend.ops.moe.token_dispatcher.torch.distributed.get_rank", return_value=0), \
-         patch("vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version", return_value=None), \
+    with patch("vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
+         patch("vllm_ascend.ops.fused_moe.token_dispatcher.torch.distributed.get_rank", return_value=0), \
+         patch("vllm_ascend.ops.fused_moe.token_dispatcher.get_ascend_soc_version", return_value=None), \
          patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
                     _PP=pp_group), \
          patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group), \
@@ -20,7 +20,7 @@ import torch
 from pytest_mock import MockerFixture
 
 from tests.ut.base import PytestBase
-from vllm_ascend.ops.moe.comm_utils import (
+from vllm_ascend.ops.fused_moe.comm_utils import (
     _gather_along_first_dim, async_all_to_all,
     gather_from_sequence_parallel_region)
 
@@ -1,56 +0,0 @@
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# This file is a part of the vllm-ascend project.
-#
-from unittest.mock import patch
-
-import torch
-
-from tests.ut.base import TestBase
-from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
-
-
-class TestLoadWeight(TestBase):
-
-    def test_load_w13_transpose(self):
-        with patch.object(AscendFusedMoE, "__init__",
-                          lambda self, *args, **kwargs: None):
-            moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
-
-            expert_data = torch.randn(128, 8)
-            loaded_weight = torch.randn(128, 4)
-            moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
-
-            expert_data = torch.randn(8, 128)
-            loaded_weight = torch.randn(128, 4)
-            moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
-
-            expert_data = torch.randn(128, 8)
-            loaded_weight = torch.randn(128, 4)
-            moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
-
-            expert_data = torch.randn(8, 128)
-            loaded_weight = torch.randn(128, 4)
-            moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
-
-    def test_load_w2_transpose(self):
-        with patch.object(AscendFusedMoE, "__init__",
-                          lambda self, *args, **kwargs: None):
-            moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
-            expert_data = torch.randn(128, 4)
-            loaded_weight = torch.randn(128, 8)
-            moe._load_w2(expert_data, 1, loaded_weight, 0)
-
-            expert_data = torch.randn(4, 128)
-            loaded_weight = torch.randn(128, 8)
-            moe._load_w2(expert_data, 1, loaded_weight, 0)
@@ -24,9 +24,11 @@ from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
 
 from tests.ut.base import TestBase
 from vllm_ascend.ascend_forward_context import MoECommType
-from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
-from vllm_ascend.ops.moe.experts_selector import select_experts
-from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp
+from vllm_ascend.ops.fused_moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.fused_moe import (
+    AscendFusedMoE, AscendUnquantizedFusedMoEMethod)
+from vllm_ascend.ops.fused_moe.moe_mlp import (cumsum_group_list,
+                                               unified_apply_mlp)
 from vllm_ascend.utils import AscendSocVersion, adapt_patch
 
 adapt_patch(True)
@@ -69,10 +71,11 @@ def setup_vllm_config_mock(mocker: MockerFixture):
     mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
     mock_vllm_config.model_config.max_model_len = 2048
 
-    mocker.patch('vllm_ascend.ops.common_fused_moe.get_current_vllm_config',
+    mocker.patch('vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config',
                  return_value=mock_vllm_config)
-    mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
-                 return_value=mock_vllm_config)
+    mocker.patch(
+        'vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config',
+        return_value=mock_vllm_config)
 
 
 @pytest.fixture
@@ -105,37 +108,37 @@ def mock_dist_env(mocker: MockerFixture):
 
     with patch('torch.distributed.get_rank', return_value=0), \
          patch('torch.distributed.get_world_size', return_value=4), \
-         patch('vllm_ascend.ops.common_fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
-         patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
-         patch('vllm_ascend.ops.common_fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
-         patch('vllm_ascend.ops.common_fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
+         patch('vllm_ascend.ops.fused_moe.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
+         patch('vllm_ascend.ops.fused_moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
+         patch('vllm_ascend.ops.fused_moe.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
+         patch('vllm_ascend.ops.fused_moe.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
          patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
-         patch('vllm_ascend.ops.common_fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
+         patch('vllm_ascend.ops.fused_moe.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
          patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
          patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
                return_value=mock_dp_and_tp_group(mocker)), \
-         patch('vllm_ascend.ops.common_fused_moe.get_ascend_config',
+         patch('vllm_ascend.ops.fused_moe.fused_moe.get_ascend_config',
                return_value=MagicMock(
                    torchair_graph_config=MagicMock(enabled=False),
                    enable_multistream_moe=False,
                    expert_map_path=None
                )), \
-         patch('vllm_ascend.ops.common_fused_moe.determine_expert_map',
+         patch('vllm_ascend.ops.fused_moe.fused_moe.determine_expert_map',
                return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
-         patch('vllm_ascend.ops.common_fused_moe.get_forward_context',
+         patch('vllm_ascend.ops.fused_moe.fused_moe.get_forward_context',
                return_value=mock_forward_context_obj), \
-         patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
+         patch('vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context',
                return_value=mock_forward_context_obj), \
          patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
-         patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
+         patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context',
                return_value=mock_forward_context_obj), \
-         patch('vllm_ascend.ops.moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
+         patch('vllm_ascend.ops.fused_moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
                return_value=None), \
-         patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
+         patch('vllm_ascend.ops.fused_moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
                return_value=None), \
-         patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
+         patch('vllm_ascend.ops.fused_moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
                return_value=None), \
-         patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
+         patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
                return_value=mock_forward_context_obj):
 
         yield {
@@ -319,8 +322,8 @@ class TestCumsumGroupList(TestBase):
 
 class TestUnifiedApplyMLP(TestBase):
 
-    @patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context')
-    @patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
+    @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
+    @patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_dynamic_quant')
     @patch('torch_npu.npu_dequant_swiglu_quant')
@@ -384,7 +387,7 @@ class TestUnifiedApplyMLP(TestBase):
 
         self.assertEqual(result.dtype, torch.bfloat16)
 
-    @patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
+    @patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_swiglu')
     @patch('torch_npu.npu_dynamic_quant')
@@ -426,7 +429,7 @@ class TestUnifiedApplyMLP(TestBase):
         self.assertEqual(result.shape, hidden_states.shape)
         self.assertEqual(result.dtype, torch.float16)
 
-    @patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context')
+    @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_swiglu')
     @patch('torch_npu.npu_dynamic_quant')
@@ -486,7 +489,7 @@ class TestUnifiedApplyMLP(TestBase):
         self.assertEqual(result.shape, hidden_states.shape)
         self.assertEqual(result.dtype, torch.bfloat16)
 
-    @patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
+    @patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_swiglu')
     @patch('torch_npu.npu_dynamic_quant')
@@ -531,7 +534,7 @@ class TestUnifiedApplyMLP(TestBase):
         self.assertEqual(result.shape, hidden_states.shape)
         self.assertEqual(result.dtype, torch.float16)
 
-    @patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context")
+    @patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context")
     @patch("torch_npu.npu_grouped_matmul")
     @patch("torch_npu.npu_swiglu")
     @patch("torch_npu.npu_grouped_matmul_swiglu_quant")
@@ -595,3 +598,39 @@ class TestUnifiedApplyMLP(TestBase):
         self.assertTrue(mock_forward_context.with_quant)
         self.assertEqual(result.shape, hidden_states.shape)
         self.assertEqual(result.dtype, torch.bfloat16)
+
+
+class TestLoadWeight(TestBase):
+
+    def test_load_w13_transpose(self):
+        with patch.object(AscendFusedMoE, "__init__",
+                          lambda self, *args, **kwargs: None):
+            moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
+
+            expert_data = torch.randn(128, 8)
+            loaded_weight = torch.randn(128, 4)
+            moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
+
+            expert_data = torch.randn(8, 128)
+            loaded_weight = torch.randn(128, 4)
+            moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
+
+            expert_data = torch.randn(128, 8)
+            loaded_weight = torch.randn(128, 4)
+            moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
+
+            expert_data = torch.randn(8, 128)
+            loaded_weight = torch.randn(128, 4)
+            moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
+
+    def test_load_w2_transpose(self):
+        with patch.object(AscendFusedMoE, "__init__",
+                          lambda self, *args, **kwargs: None):
+            moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
+            expert_data = torch.randn(128, 4)
+            loaded_weight = torch.randn(128, 8)
+            moe._load_w2(expert_data, 1, loaded_weight, 0)
+
+            expert_data = torch.randn(4, 128)
+            loaded_weight = torch.randn(128, 8)
+            moe._load_w2(expert_data, 1, loaded_weight, 0)
@@ -4,8 +4,9 @@ import torch
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 
 from tests.ut.base import TestBase
-from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
-                                                 AlltoAllCommImpl, MC2CommImpl)
+from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
+                                                       AlltoAllCommImpl,
+                                                       MC2CommImpl)
 
 
 class TestMoECommMethod(TestBase):
@@ -24,12 +25,14 @@ class TestMoECommMethod(TestBase):
         self.moe_config.dp_group = MagicMock()
         self.moe_config.num_global_redundant_experts = 0
 
-    @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
-    @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
+    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
+    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
     @patch(
-        "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
+        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
     )
-    @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
+    @patch(
+        "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
+    )
     def test_all_gather_comm_impl(self, mock_token_dispatcher,
                                   mock_prepare_finalize,
                                   mock_get_forward_context,
@@ -72,12 +75,11 @@ class TestMoECommMethod(TestBase):
                                 context_metadata=context_metadata)
         mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
 
-    @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
-    @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
+    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
+    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
     @patch(
-        "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithMC2"
-    )
-    @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithMC2")
+        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithMC2")
+    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithMC2")
     def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
                            mock_get_forward_context,
                            mock_get_current_vllm_config):
@@ -121,12 +123,14 @@ class TestMoECommMethod(TestBase):
                                 context_metadata=context_metadata)
         mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
 
-    @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
-    @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
+    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
+    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
     @patch(
-        "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAll2All"
+        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAll2All"
     )
-    @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAll2AllV")
+    @patch(
+        "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAll2AllV"
+    )
     def test_alltoall_comm_impl(self, mock_token_dispatcher,
                                 mock_prepare_finalize,
                                 mock_get_forward_context,
@@ -163,13 +167,15 @@ class TestMoECommMethod(TestBase):
         mock_pf_instance.prepare.assert_called_once_with(
             hidden_states, router_logits, False, False, None)
 
-    @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
-    @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
+    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
+    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
     @patch(
-        "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
+        "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
     )
-    @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
-    @patch("vllm_ascend.ops.moe.moe_comm_method.unified_apply_mlp")
+    @patch(
+        "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
+    )
+    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.unified_apply_mlp")
     def test_fused_experts_method(self, mock_unified_apply_mlp,
                                   mock_token_dispatcher, mock_prepare_finalize,
                                   mock_get_forward_context,
@@ -4,13 +4,12 @@ from unittest.mock import MagicMock, patch
 import torch
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 
-from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
-    FusedMoEPrepareAndFinalizeWithAll2All,
-    FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
-    FusedMoEPrepareAndFinalizeWithNaiveMulticast)
+from vllm_ascend.ops.fused_moe.prepare_finalize import (
+    PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather,
+    PrepareAndFinalizeWithMC2, PrepareAndFinalizeWithNaiveMulticast)
 
 
-class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
+class TestPrepareAndFinalize(unittest.TestCase):
 
     def setUp(self):
         # Mock FusedMoEConfig
@@ -24,14 +23,12 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
         self.moe_config.original_num_experts = 8
 
     @patch(
-        "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
+        "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
         return_value=1)
     @patch(
-        "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
+        "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
         return_value=0)
-    @patch(
-        "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
-    )
+    @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
     def test_mc2_prepare_finalize(self, mock_get_forward_context, mock_tp_rank,
                                   mock_tp_size):
         mock_context = MagicMock()
@@ -39,7 +36,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
         mock_context.padded_num_tokens = 4
         mock_get_forward_context.return_value = mock_context
 
-        layer = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
+        layer = PrepareAndFinalizeWithMC2(self.moe_config)
 
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
@@ -59,14 +56,12 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
         self.assertEqual(result.shape[0], 3)
 
     @patch(
-        "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
+        "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
         return_value=2)
     @patch(
-        "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
+        "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
         return_value=0)
-    @patch(
-        "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
-    )
+    @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
    @patch("torch.distributed.all_gather")
     def test_mc2_tp_split_allgather(self, mock_all_gather,
                                     mock_get_forward_context, mock_tp_rank,
@@ -76,7 +71,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
         mock_context.padded_num_tokens = 4
         mock_get_forward_context.return_value = mock_context
 
-        layer = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
+        layer = PrepareAndFinalizeWithMC2(self.moe_config)
         hidden_states = torch.randn(4, 8)
         router_logits = torch.randn(4, 2)
 
@@ -108,13 +103,13 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
         self.assertEqual(final_result.shape[0], 4)
 
     @patch(
-        "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
+        "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
         return_value=1)
     @patch(
-        "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
+        "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
         return_value=0)
     def test_all2all_prepare_finalize(self, mock_tp_rank, mock_tp_size):
-        layer = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)
+        layer = PrepareAndFinalizeWithAll2All(self.moe_config)
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
 
@@ -130,15 +125,15 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
         self.assertEqual(result.shape[0], 3)
 
     @patch(
-        "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
+        "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
         return_value=2)
     @patch(
-        "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
+        "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
         return_value=0)
     @patch("torch.distributed.all_gather")
     def test_all2all_tp_split_allgather(self, mock_all_gather, mock_tp_rank,
                                         mock_tp_size):
-        layer = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)
+        layer = PrepareAndFinalizeWithAll2All(self.moe_config)
         hidden_states = torch.randn(2, 8)
         router_logits = torch.randn(2, 2)
 
@@ -169,14 +164,15 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
         # Should concat back
         self.assertEqual(final_result.shape[0], 2)
 
-    @patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group")
+    @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_dp_group")
     @patch(
-        "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce"
+        "vllm_ascend.ops.fused_moe.prepare_finalize.tensor_model_parallel_all_reduce"
     )
-    @patch(
-        "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
-    )
-    def test_allgather_prepare_finalize(self, mock_get_forward_context,
+    @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
+    @patch("vllm_ascend.ops.fused_moe.prepare_finalize.enable_sp",
+           return_value=False)
+    def test_allgather_prepare_finalize(self, mock_enable_sp,
+                                        mock_get_forward_context,
                                         mock_tp_all_reduce, mock_get_dp_group):
         # Mock forward context
         mock_context = MagicMock()
@@ -198,7 +194,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
         self.moe_config.ep_size = 1
         self.moe_config.dp_group = mock_dp_group
 
-        layer = FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config)
+        layer = PrepareAndFinalizeWithAllGather(self.moe_config)
 
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
@@ -232,13 +228,11 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
         result_with_tp = layer.finalize(h_out, reduce_results=True)
         self.assertEqual(result_with_tp.shape[0], 3)
 
-    @patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group")
+    @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_dp_group")
     @patch(
-        "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce"
-    )
-    @patch(
-        "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
+        "vllm_ascend.ops.fused_moe.prepare_finalize.tensor_model_parallel_all_reduce"
     )
+    @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
     def test_naive_multicast_prepare_finalize(self, mock_get_forward_context,
                                               mock_tp_all_reduce,
                                               mock_get_dp_group):
@@ -266,7 +260,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
         self.moe_config.tp_size = 1
         self.moe_config.ep_size = 1
 
-        layer = FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config)
+        layer = PrepareAndFinalizeWithNaiveMulticast(self.moe_config)
 
         # Local inputs
         hidden_states = torch.randn(3, 8)
@@ -21,7 +21,7 @@ import torch
 
 from tests.ut.base import TestBase
 
-from vllm_ascend.ops.moe.token_dispatcher import (  # isort: skip
+from vllm_ascend.ops.fused_moe.token_dispatcher import (  # isort: skip
     AscendSocVersion, TokenDispatcherWithAll2AllV,
     TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
 
@@ -34,7 +34,7 @@ class TestTokenDispatcherWithMC2(TestBase):
         self.mc2_group.rank_in_group = 0
         self.mc2_group.world_size = 8
         self.mc2_group_patch = patch(
-            "vllm_ascend.ops.moe.token_dispatcher.get_mc2_group",
+            "vllm_ascend.ops.fused_moe.token_dispatcher.get_mc2_group",
             return_value=self.mc2_group)
         self.mc2_group_patch.start()
 
@@ -52,7 +52,7 @@ class TestTokenDispatcherWithMC2(TestBase):
 
         # Mock get_ascend_soc_version()
         self.ascend_soc_version_patch = patch(
-            "vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version",
+            "vllm_ascend.ops.fused_moe.token_dispatcher.get_ascend_soc_version",
             return_value=AscendSocVersion.A3)
         self.ascend_soc_version_patch.start()
 
@@ -369,7 +369,8 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
         self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16)
 
         # Mock async_all_to_all
-        patcher6 = patch('vllm_ascend.ops.moe.comm_utils.async_all_to_all')
+        patcher6 = patch(
+            'vllm_ascend.ops.fused_moe.comm_utils.async_all_to_all')
         self.mock_async_all_to_all = patcher6.start()
         self.addCleanup(patcher6.stop)
         self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16),
@@ -377,7 +378,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
 
         # Mock gather_from_sequence_parallel_region
         patcher7 = patch(
-            'vllm_ascend.ops.moe.token_dispatcher.gather_from_sequence_parallel_region'
+            'vllm_ascend.ops.fused_moe.token_dispatcher.gather_from_sequence_parallel_region'
         )
         self.mock_gather_from_sequence_parallel_region = patcher7.start()
         self.addCleanup(patcher7.stop)
@@ -5,8 +5,8 @@ import torch
 
 from tests.ut.base import TestBase
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.ops.moe.experts_selector import (_native_grouped_topk,
-                                                  select_experts)
+from vllm_ascend.ops.fused_moe.experts_selector import (_native_grouped_topk,
+                                                        select_experts)
 from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
                                            AscendW8A8FusedMoEMethod,
                                            AscendW8A8LinearMethod,
@@ -758,7 +758,7 @@ class TestSelectExperts(TestBase):
         self.mock_ctx = MagicMock()
         self.mock_ctx.weight_prefetch_method = MagicMock()
         patcher = patch(
-            'vllm_ascend.ops.moe.experts_selector.get_forward_context',
+            'vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
             return_value=self.mock_ctx)
         self.addCleanup(patcher.stop)
         patcher.start()
@@ -831,7 +831,7 @@ class TestSelectExperts(TestBase):
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.dtype, torch.int32)
 
-    @patch('vllm_ascend.ops.moe.experts_selector._native_grouped_topk')
+    @patch('vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk')
     def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
         """Test grouped topk with expert score correction bias"""
         mock_grouped_topk.return_value = torch.ones(self.num_tokens,