[Main] [Refactor] Enable MoECommMethod in Eager Mode (#2791)

### What this PR does / why we need it?
1. Replace the prepare/finalize operations in fused_moe.py with
moe_comm_method.prepare()/finalize().
2. Replace unified_fused_experts with moe_comm_method.fused_experts() in
fused_moe.py, w8a8_dynamic.py, and w4a8_dynamic.py.
3. Add a call to _select_moe_comm_method in the spec-decode proposers.
4. Currently, w4a8_dynamic does not support gatherep, so all2allv is
used instead.
5. Remove redundant code.
### Does this PR introduce _any_ user-facing change?
The AllgatherEP switch is disabled in aclgraph/eager mode; the communication
method now follows the rules in modelrunner_v1._select_moe_comm_method().
### How was this patch tested?
Verified with e2e tests and unit tests (ut).


- vLLM version: v0.10.2
- vLLM main:
7f6f2c1182

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
weichen
2025-09-16 11:06:00 +08:00
committed by GitHub
parent 0aba644633
commit 18ca7861f6
18 changed files with 523 additions and 596 deletions

View File

@@ -23,8 +23,7 @@ from vllm.distributed.parallel_state import GroupCoordinator
from vllm_ascend.models.deepseek_v2 import (
CustomDeepseekV2MergedReplicatedLinear, CustomDeepseekV2MLAAttention,
CustomDeepseekV2MLP, CustomDeepseekV2MoE,
CustomDeepseekV2RowParallelLinear,
CustomDeepseekV2MLP, CustomDeepseekV2RowParallelLinear,
CustomDeepseekV2RowParallelLinearReplaceAllreduce,
CustomDeepseekV2SiluAndMul, LogitsProcessor, ParallelLMHead)
@@ -213,22 +212,6 @@ def test_custom_deepseek_v2_mlp(mock_distributed, base_config):
quant_config=None)
def test_custom_deepseek_v2_moe(mock_distributed, base_config,
mock_forward_context):
base_config.n_shared_experts = 1
moe = CustomDeepseekV2MoE(config=base_config,
quant_config=None,
prefix="mlp")
assert moe.top_k == 2
x = torch.randn(2, 4, 128)
attn_metadata = Mock(num_prefills=1)
with patch("vllm_ascend.ops.fused_moe.AscendFusedMoE.__call__",
return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
output = moe(x, attn_metadata)
assert output.shape == (2, 4, 128)
@patch("torch_npu.npu_rms_norm")
def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
base_config):

View File

@@ -1,22 +0,0 @@
import os
import unittest
from unittest import mock
from vllm_ascend.ascend_forward_context import get_dispatcher_name
class TestGetDispatcherName(unittest.TestCase):
def test_get_dispatcher_name(self):
result = get_dispatcher_name(1, False)
assert result == "TokenDispatcherWithAllGather"
result = get_dispatcher_name(4, False)
assert result == "TokenDispatcherWithAll2AllV"
result = get_dispatcher_name(16, True)
assert result == "TokenDispatcherWithAll2AllV"
result = get_dispatcher_name(16, False)
assert result == "TokenDispatcherWithMC2"
with mock.patch.dict(os.environ,
{"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP": "1"}):
result = get_dispatcher_name(16, False)
assert result == "TokenDispatcherWithAllGather"

View File

@@ -6,7 +6,8 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
FusedMoEPrepareAndFinalizeWithAll2All,
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2)
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
FusedMoEPrepareAndFinalizeWithNaiveMulticast)
class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
@@ -216,3 +217,68 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
mock_tp_all_reduce.return_value = result
result_with_tp = layer.finalize(h_out, reduce_results=True)
self.assertEqual(result_with_tp.shape[0], 3)
@patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group")
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce"
)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
)
def test_naive_multicast_prepare_finalize(self, mock_get_forward_context,
mock_tp_all_reduce,
mock_get_dp_group):
# Mock forward context with DP metadata
mock_context = MagicMock()
mock_context.dp_metadata.cu_tokens_across_dp_cpu = torch.tensor(
[2, 5, 7])
mock_get_forward_context.return_value = mock_context
# Setup DP group mock
mock_dp_group = MagicMock()
mock_dp_group.broadcast = MagicMock()
mock_dp_group.all_reduce = MagicMock()
mock_get_dp_group.return_value = mock_dp_group
# Mock all_reduce to just return input (simulate sum)
def mock_all_reduce(tensor):
return tensor * 2
mock_dp_group.all_reduce.side_effect = mock_all_reduce
# Setup config
self.moe_config.dp_size = 3
self.moe_config.dp_rank = 1
self.moe_config.tp_size = 1
self.moe_config.ep_size = 1
layer = FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config)
# Local inputs
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)
# Mock gate for router logits recomputation
mock_gate = MagicMock()
mock_gate.return_value = (torch.randn(7, 2), None)
# Run prepare
h_out, r_out, _ = layer.prepare(hidden_states,
router_logits,
rm_router_logits=False,
gate=mock_gate)
# Should be global tensor: [7, 8] and [7, 2]
self.assertEqual(h_out.shape, (7, 8))
self.assertEqual(r_out.shape, (7, 2))
# Run finalize
result = layer.finalize(h_out, reduce_results=False)
# Should slice back to local: [3, 8]
self.assertEqual(result.shape, (3, 8))
# Test with reduce_results=True and TP/EP > 1
mock_tp_all_reduce.return_value = result
result_with_tp = layer.finalize(h_out, reduce_results=True)
self.assertEqual(result_with_tp.shape, (3, 8))

View File

@@ -22,10 +22,7 @@ import torch_npu
from pytest_mock import MockerFixture
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
import vllm_ascend.ops.moe.token_dispatcher as token_dispatcher_module
from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import (FusedMoEState,
_get_fused_moe_state)
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
AscendUnquantizedFusedMoEMethod)
from vllm_ascend.ops.moe.experts_selector import select_experts
@@ -60,68 +57,24 @@ def mock_npu_format_cast(weight_data, format):
@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
mock_setup_token_dispatchers = MagicMock()
mock_token_dispatcher_with_allgather = MagicMock()
mock_token_dispatcher_with_all2allv = MagicMock()
mock_token_dispatcher_with_mc2 = MagicMock()
mock_moe_comm_method = MagicMock()
mock_dispatch_result_allgather = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([8, 16], dtype=torch.int64),
"group_list_type": 0,
}
mock_combine_result_allgather = torch.randn(16, 2)
def mock_prepare(hidden_states, router_logits, **kwargs):
return hidden_states, router_logits
mock_token_dispatcher_with_allgather.token_dispatch.return_value = mock_dispatch_result_allgather
mock_token_dispatcher_with_allgather.token_combine.return_value = mock_combine_result_allgather
mock_moe_comm_method.prepare.side_effect = mock_prepare
mock_dispatch_result_all2allv = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([4, 8, 12, 16], dtype=torch.int64),
"group_list_type": 1,
"dynamic_scale": None,
}
mock_combine_result_all2allv = torch.randn(16, 2)
mock_token_dispatcher_with_all2allv.token_dispatch.return_value = mock_dispatch_result_all2allv
mock_token_dispatcher_with_all2allv.token_combine.return_value = mock_combine_result_all2allv
mock_fused_experts_result = torch.randn(16, 2)
mock_moe_comm_method.fused_experts.return_value = mock_fused_experts_result
mock_dispatch_result_mc2 = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([5, 10, 15, 16], dtype=torch.int64),
"group_list_type": 1,
"dynamic_scale": None,
"assist_info_for_combine": torch.randn(16, 2),
"ep_recv_counts": torch.tensor([4, 4, 4, 4], dtype=torch.int32),
}
mock_combine_result_mc2 = torch.randn(16, 2)
mock_token_dispatcher_with_mc2.token_dispatch.return_value = mock_dispatch_result_mc2
mock_token_dispatcher_with_mc2.token_combine.return_value = mock_combine_result_mc2
def mock_finalize(hidden_states, **kwargs):
return hidden_states
captured_dispatchers = {}
def capture_register(dispatcher_instance):
key = dispatcher_instance.__class__.__name__
captured_dispatchers[key] = dispatcher_instance
if key == 'TokenDispatcherWithAllGather':
captured_dispatchers[key] = mock_token_dispatcher_with_allgather
elif key == 'TokenDispatcherWithAll2AllV':
captured_dispatchers[key] = mock_token_dispatcher_with_all2allv
elif key == 'TokenDispatcherWithMC2':
captured_dispatchers[key] = mock_token_dispatcher_with_mc2
mock_register_token_dispatcher_patcher = patch(
'vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher',
side_effect=capture_register)
mock_get_token_dispatcher_patcher = patch(
'vllm_ascend.ops.moe.token_dispatcher.get_token_dispatcher',
side_effect=lambda name: captured_dispatchers.get(name))
default_mock_token_dispatcher = mock_token_dispatcher_with_allgather
mock_moe_comm_method.finalize.side_effect = mock_finalize
mock_forward_context_obj = MagicMock(
fused_moe_state=FusedMoEState.AllGather,
token_dispatcher=default_mock_token_dispatcher,
moe_comm_method=mock_moe_comm_method,
moe_comm_method_name="mc2commimpl",
max_tokens_across_dp=10,
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
mc2_mask=torch.zeros(16, dtype=torch.bool),
@@ -131,14 +84,12 @@ def mock_dist_env(mocker: MockerFixture):
with patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \
patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('torch.distributed.all_gather'), \
patch('torch.distributed.all_to_all_single'), \
patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_ascend_config',
@@ -150,6 +101,8 @@ def mock_dist_env(mocker: MockerFixture):
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
patch('vllm_ascend.ops.fused_moe.get_forward_context',
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
return_value=MagicMock(
parallel_config=MagicMock(tensor_parallel_size=2),
@@ -157,22 +110,20 @@ def mock_dist_env(mocker: MockerFixture):
model_config=MagicMock(max_model_len=2048)
)), \
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \
patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
return_value=mock_forward_context_obj):
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
return_value=None), \
patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
return_value=None), \
patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
return_value=None):
yield {
'mock_forward_context_obj': mock_forward_context_obj,
'mock_token_dispatcher_with_allgather':
mock_token_dispatcher_with_allgather,
'mock_token_dispatcher_with_all2allv':
mock_token_dispatcher_with_all2allv,
'mock_token_dispatcher_with_mc2': mock_token_dispatcher_with_mc2,
'mock_moe_comm_method': mock_moe_comm_method,
}
mock_register_token_dispatcher_patcher.stop()
mock_get_token_dispatcher_patcher.stop()
@pytest.fixture
def mock_moe_env(mocker: MockerFixture):
@@ -338,9 +289,7 @@ class TestAscendFusedMoe:
moe.moe_parallel_config.ep_size = 1
moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
dtype=torch.bool),
padded_num_tokens=num_tokens)
forward_context = mock_dist_env['mock_forward_context_obj']
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
return_value=forward_context):
output = moe.forward(inputs,
@@ -394,25 +343,10 @@ class TestAscendUnquantizedFusedMoEMethod:
[[256, 4], [128, 1], [128, 1], [128, 4]])
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
global_num_experts, ep_size = others_param
is_prefill = False
is_deepseek_v3_r1 = global_num_experts == 256
if ep_size == 1:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_allgather']
elif ep_size < 16:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_all2allv']
else:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_mc2']
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
ep_size, is_prefill, is_deepseek_v3_r1),
with_quant=False,
token_dispatcher=selected_token_dispatcher)
forward_context = mock_dist_env['mock_forward_context_obj']
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
return_value=forward_context):
@@ -438,35 +372,22 @@ class TestAscendUnquantizedFusedMoEMethod:
global_num_experts=global_num_experts,
is_prefill=is_prefill)
expected_shape = (16, 2)
mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
mock_moe_comm_method.fused_experts.assert_called_once()
expected_shape = (16, 2)
assert result.shape == expected_shape
@pytest.mark.parametrize("others_param", [16, 1, 4])
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
ep_size = others_param
is_prefill = False
if ep_size == 1:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_allgather']
elif ep_size < 16:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_all2allv']
else:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_mc2']
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
ep_size, is_prefill, True),
with_quant=False,
token_dispatcher=selected_token_dispatcher)
forward_context = mock_dist_env['mock_forward_context_obj']
with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
@@ -493,8 +414,10 @@ class TestAscendUnquantizedFusedMoEMethod:
expert_map=expert_map,
is_prefill=is_prefill)
expected_shape = (16, 2)
mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
mock_moe_comm_method.fused_experts.assert_called_once()
expected_shape = (16, 2)
assert result.shape == expected_shape
@@ -574,7 +497,7 @@ class TestUnifiedApplyMLP(TestBase):
mock_get_forward_context):
mock_forward_context = MagicMock()
mock_forward_context.fused_moe_state = FusedMoEState.MC2
mock_forward_context.moe_comm_method_name = "mc2commimpl"
mock_get_forward_context.return_value = mock_forward_context
mock_is_310p.return_value = False
@@ -618,8 +541,6 @@ class TestUnifiedApplyMLP(TestBase):
with_quant=True)
mock_get_forward_context.assert_called()
self.assertEqual(mock_forward_context.fused_moe_state,
FusedMoEState.MC2)
mock_npu_dynamic_quant.assert_called()

View File

@@ -23,8 +23,7 @@ from tests.ut.base import TestBase
from vllm_ascend.ops.moe.token_dispatcher import ( # isort: skip
AscendSocVersion, TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather, TokenDispatcherWithMC2, _Dispatchers,
_register_token_dispatcher, get_token_dispatcher, setup_token_dispatchers)
TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
class TestTokenDispatcherWithMC2(TestBase):
@@ -521,99 +520,3 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertEqual(result["group_list_type"], 1)
class TestDispatcherRegistry(TestBase):
def setUp(self):
_Dispatchers.clear()
def tearDown(self):
_Dispatchers.clear()
def test_register_and_get_token_dispatcher(self):
mock_dispatcher = MagicMock()
mock_dispatcher.__class__.__name__ = "MockDispatcher"
_register_token_dispatcher(mock_dispatcher)
self.assertIn("MockDispatcher", _Dispatchers)
self.assertIs(_Dispatchers["MockDispatcher"], mock_dispatcher)
retrieved_dispatcher = get_token_dispatcher("MockDispatcher")
self.assertIs(retrieved_dispatcher, mock_dispatcher)
self.assertIsNone(get_token_dispatcher("NonExistentDispatcher"))
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAllGather')
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
def test_setup_token_dispatchers_ep_size_1_creates_allgather(
self, mock_register, mock_allgather_class):
kwargs = {"top_k": 2, "num_experts": 8}
mock_instance = MagicMock()
mock_allgather_class.return_value = mock_instance
self.assertNotIn("TokenDispatcherWithAllGather", _Dispatchers)
setup_token_dispatchers(ep_size=1, **kwargs)
mock_allgather_class.assert_called_once_with(**kwargs)
mock_register.assert_called_once_with(mock_instance)
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
def test_setup_token_dispatchers_ep_size_2_creates_all2allv(
self, mock_register, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 16, "num_local_experts": 2}
mock_instance = MagicMock()
mock_all2allv_class.return_value = mock_instance
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
setup_token_dispatchers(ep_size=2, **kwargs)
mock_all2allv_class.assert_called_once_with(**kwargs)
mock_register.assert_called_once_with(mock_instance)
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithMC2')
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2(
self, mock_register, mock_mc2_class, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
mock_all2allv_instance = MagicMock()
mock_mc2_instance = MagicMock()
mock_all2allv_class.return_value = mock_all2allv_instance
mock_mc2_class.return_value = mock_mc2_instance
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
self.assertNotIn("TokenDispatcherWithMC2", _Dispatchers)
setup_token_dispatchers(ep_size=16, **kwargs)
mock_all2allv_class.assert_called_once_with(**kwargs)
mock_mc2_class.assert_called_once_with(**kwargs)
self.assertEqual(mock_register.call_count, 2)
mock_register.assert_any_call(mock_all2allv_instance)
mock_register.assert_any_call(mock_mc2_instance)
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithMC2')
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
def test_setup_token_dispatchers_ep_size_16_skips_if_exist(
self, mock_register, mock_mc2_class, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
mock_existing_all2allv = MagicMock()
mock_existing_mc2 = MagicMock()
_Dispatchers["TokenDispatcherWithAll2AllV"] = mock_existing_all2allv
_Dispatchers["TokenDispatcherWithMC2"] = mock_existing_mc2
setup_token_dispatchers(ep_size=16, **kwargs)
mock_all2allv_class.assert_not_called()
mock_mc2_class.assert_not_called()
mock_register.assert_not_called()
self.assertIs(_Dispatchers["TokenDispatcherWithAll2AllV"],
mock_existing_all2allv)
self.assertIs(_Dispatchers["TokenDispatcherWithMC2"],
mock_existing_mc2)

View File

@@ -21,37 +21,31 @@ from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
# yapf: disable
@pytest.mark.parametrize(
"soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, expected_method",
"soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method",
[
# Case 1: Expert parallel is disabled, should always be 'allgather'
(AscendSocVersion.A2, False, 8, 100, 256, "allgather"),
(AscendSocVersion.A3, False, 16, 500, 256, "allgather"),
(AscendSocVersion.A2, False, 8, 100, 256, None, "allgather"),
(AscendSocVersion.A3, False, 16, 500, 256, None, "allgather"),
# Case 2: A2 SOC
# 2.1: MC2 conditions met (tokens <= capacity, world_size >= 16)
(AscendSocVersion.A2, True, 16, 100, 256, "mc2"),
(AscendSocVersion.A2, True, 32, 256, 256, "mc2"),
# 2.2: MC2 token capacity exceeded
(AscendSocVersion.A2, True, 16, 257, 256, "allgather"),
# 2.3: MC2 world size not met
(AscendSocVersion.A2, True, 8, 100, 256, "allgather"),
(AscendSocVersion.A2, True, 15, 100, 256, "allgather"),
# Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2
(AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", "alltoall"),
(AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", "alltoall"),
(AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", "mc2"), # meets mc2 condition
# Case 3: A3 SOC
# 3.1: MC2 condition met (tokens <= capacity)
(AscendSocVersion.A3, True, 8, 100, 256, "mc2"),
(AscendSocVersion.A3, True, 16, 256, 256, "mc2"),
# 3.2: MC2 token capacity exceeded
(AscendSocVersion.A3, True, 8, 257, 256, "alltoall"),
(AscendSocVersion.A3, True, 16, 500, 256, "alltoall"),
# Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather
(AscendSocVersion.A2, True, 8, 100, 256, None, "allgather"),
(AscendSocVersion.A2, True, 16, 257, 256, None, "allgather"),
# Case 4: A3 SOC
(AscendSocVersion.A3, True, 8, 100, 256, None, "mc2"),
(AscendSocVersion.A3, True, 8, 257, 256, None, "alltoall"),
])
# yapf: enable
def test_select_moe_comm_method(soc_version, enable_expert_parallel,
world_size, num_tokens, mc2_tokens_capacity,
expected_method):
quant_type, expected_method):
"""
Tests the _select_moe_comm_method with various configurations.
Tests the _select_moe_comm_method with various configurations including quant_type.
"""
# Mock the NPUModelRunner instance and its dependencies
mock_runner = MagicMock(spec=NPUModelRunner)
@@ -60,15 +54,24 @@ def test_select_moe_comm_method(soc_version, enable_expert_parallel,
mock_runner.parallel_config.world_size_across_dp = world_size
mock_runner.mc2_tokens_capacity = mc2_tokens_capacity
# Add vllm_config.model_config.hf_config mock with moe_quantize
mock_hf_config = MagicMock()
mock_hf_config.moe_quantize = quant_type
mock_model_config = MagicMock()
mock_model_config.hf_config = mock_hf_config
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_runner.vllm_config = mock_vllm_config
# Patch the helper functions
with patch('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version',
return_value=soc_version), \
patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank',
return_value=True):
# Call the method under test
# Bind the real method to the mock object
method = NPUModelRunner._select_moe_comm_method(
mock_runner, num_tokens)
mock_runner, num_tokens, False)
# Assert the result
assert method == expected_method
@@ -83,6 +86,15 @@ def test_select_moe_comm_method_unsupported_soc():
mock_runner.parallel_config.enable_expert_parallel = True
mock_runner.mc2_tokens_capacity = 256
# Add vllm_config.model_config.hf_config mock with moe_quantize
mock_hf_config = MagicMock()
mock_hf_config.moe_quantize = None
mock_model_config = MagicMock()
mock_model_config.hf_config = mock_hf_config
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_runner.vllm_config = mock_vllm_config
unsupported_soc = "UnsupportedSOC"
with patch('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version',
@@ -91,4 +103,4 @@ def test_select_moe_comm_method_unsupported_soc():
return_value=True), \
pytest.raises(ValueError, match=f"Unsupported soc_version: {unsupported_soc}"):
NPUModelRunner._select_moe_comm_method(mock_runner, 100)
NPUModelRunner._select_moe_comm_method(mock_runner, 100, False)