[Main] [Refactor] Enable MoECommMethod in Eager Mode (#2791)
### What this PR does / why we need it?
1. Replace prepare/finalize operation in fused_moe.py by
moe_comm_method.prepare()/finalize()
2. Replace unified_fused_experts by moe_comm_method.fused_experts() in
fused_moe.py/w8a8_dynamic.py/w4a8_dynamic.py
3. Add calling _select_moe_comm_method in spec-decode proposers.
4. Currently, w4a8_dynamic does not support gatherep, use all2allv
instead.
5. Remove redundant code.
### Does this PR introduce _any_ user-facing change?
AllgatherEP switch is disabled in aclgraph/eager mode, just follow the
rules in modelrunner_v1._select_moe_comm_method()
### How was this patch tested?
e2e & ut
- vLLM version: v0.10.2
- vLLM main:
7f6f2c1182
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
@@ -23,8 +23,7 @@ from vllm.distributed.parallel_state import GroupCoordinator
|
||||
|
||||
from vllm_ascend.models.deepseek_v2 import (
|
||||
CustomDeepseekV2MergedReplicatedLinear, CustomDeepseekV2MLAAttention,
|
||||
CustomDeepseekV2MLP, CustomDeepseekV2MoE,
|
||||
CustomDeepseekV2RowParallelLinear,
|
||||
CustomDeepseekV2MLP, CustomDeepseekV2RowParallelLinear,
|
||||
CustomDeepseekV2RowParallelLinearReplaceAllreduce,
|
||||
CustomDeepseekV2SiluAndMul, LogitsProcessor, ParallelLMHead)
|
||||
|
||||
@@ -213,22 +212,6 @@ def test_custom_deepseek_v2_mlp(mock_distributed, base_config):
|
||||
quant_config=None)
|
||||
|
||||
|
||||
def test_custom_deepseek_v2_moe(mock_distributed, base_config,
|
||||
mock_forward_context):
|
||||
base_config.n_shared_experts = 1
|
||||
moe = CustomDeepseekV2MoE(config=base_config,
|
||||
quant_config=None,
|
||||
prefix="mlp")
|
||||
assert moe.top_k == 2
|
||||
|
||||
x = torch.randn(2, 4, 128)
|
||||
attn_metadata = Mock(num_prefills=1)
|
||||
with patch("vllm_ascend.ops.fused_moe.AscendFusedMoE.__call__",
|
||||
return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
|
||||
output = moe(x, attn_metadata)
|
||||
assert output.shape == (2, 4, 128)
|
||||
|
||||
|
||||
@patch("torch_npu.npu_rms_norm")
|
||||
def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
|
||||
base_config):
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
import os
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
from vllm_ascend.ascend_forward_context import get_dispatcher_name
|
||||
|
||||
|
||||
class TestGetDispatcherName(unittest.TestCase):
|
||||
|
||||
def test_get_dispatcher_name(self):
|
||||
result = get_dispatcher_name(1, False)
|
||||
assert result == "TokenDispatcherWithAllGather"
|
||||
result = get_dispatcher_name(4, False)
|
||||
assert result == "TokenDispatcherWithAll2AllV"
|
||||
result = get_dispatcher_name(16, True)
|
||||
assert result == "TokenDispatcherWithAll2AllV"
|
||||
result = get_dispatcher_name(16, False)
|
||||
assert result == "TokenDispatcherWithMC2"
|
||||
with mock.patch.dict(os.environ,
|
||||
{"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP": "1"}):
|
||||
result = get_dispatcher_name(16, False)
|
||||
assert result == "TokenDispatcherWithAllGather"
|
||||
@@ -6,7 +6,8 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
|
||||
FusedMoEPrepareAndFinalizeWithAll2All,
|
||||
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2)
|
||||
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
|
||||
FusedMoEPrepareAndFinalizeWithNaiveMulticast)
|
||||
|
||||
|
||||
class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
@@ -216,3 +217,68 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
mock_tp_all_reduce.return_value = result
|
||||
result_with_tp = layer.finalize(h_out, reduce_results=True)
|
||||
self.assertEqual(result_with_tp.shape[0], 3)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce"
|
||||
)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
|
||||
)
|
||||
def test_naive_multicast_prepare_finalize(self, mock_get_forward_context,
|
||||
mock_tp_all_reduce,
|
||||
mock_get_dp_group):
|
||||
# Mock forward context with DP metadata
|
||||
mock_context = MagicMock()
|
||||
mock_context.dp_metadata.cu_tokens_across_dp_cpu = torch.tensor(
|
||||
[2, 5, 7])
|
||||
mock_get_forward_context.return_value = mock_context
|
||||
|
||||
# Setup DP group mock
|
||||
mock_dp_group = MagicMock()
|
||||
mock_dp_group.broadcast = MagicMock()
|
||||
mock_dp_group.all_reduce = MagicMock()
|
||||
mock_get_dp_group.return_value = mock_dp_group
|
||||
|
||||
# Mock all_reduce to just return input (simulate sum)
|
||||
def mock_all_reduce(tensor):
|
||||
return tensor * 2
|
||||
|
||||
mock_dp_group.all_reduce.side_effect = mock_all_reduce
|
||||
|
||||
# Setup config
|
||||
self.moe_config.dp_size = 3
|
||||
self.moe_config.dp_rank = 1
|
||||
self.moe_config.tp_size = 1
|
||||
self.moe_config.ep_size = 1
|
||||
|
||||
layer = FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config)
|
||||
|
||||
# Local inputs
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
|
||||
# Mock gate for router logits recomputation
|
||||
mock_gate = MagicMock()
|
||||
mock_gate.return_value = (torch.randn(7, 2), None)
|
||||
|
||||
# Run prepare
|
||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
rm_router_logits=False,
|
||||
gate=mock_gate)
|
||||
|
||||
# Should be global tensor: [7, 8] and [7, 2]
|
||||
self.assertEqual(h_out.shape, (7, 8))
|
||||
self.assertEqual(r_out.shape, (7, 2))
|
||||
|
||||
# Run finalize
|
||||
result = layer.finalize(h_out, reduce_results=False)
|
||||
|
||||
# Should slice back to local: [3, 8]
|
||||
self.assertEqual(result.shape, (3, 8))
|
||||
|
||||
# Test with reduce_results=True and TP/EP > 1
|
||||
mock_tp_all_reduce.return_value = result
|
||||
result_with_tp = layer.finalize(h_out, reduce_results=True)
|
||||
self.assertEqual(result_with_tp.shape, (3, 8))
|
||||
|
||||
@@ -22,10 +22,7 @@ import torch_npu
|
||||
from pytest_mock import MockerFixture
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
|
||||
|
||||
import vllm_ascend.ops.moe.token_dispatcher as token_dispatcher_module
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ascend_forward_context import (FusedMoEState,
|
||||
_get_fused_moe_state)
|
||||
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
|
||||
AscendUnquantizedFusedMoEMethod)
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
@@ -60,68 +57,24 @@ def mock_npu_format_cast(weight_data, format):
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dist_env(mocker: MockerFixture):
|
||||
mock_setup_token_dispatchers = MagicMock()
|
||||
mock_token_dispatcher_with_allgather = MagicMock()
|
||||
mock_token_dispatcher_with_all2allv = MagicMock()
|
||||
mock_token_dispatcher_with_mc2 = MagicMock()
|
||||
mock_moe_comm_method = MagicMock()
|
||||
|
||||
mock_dispatch_result_allgather = {
|
||||
"hidden_states": torch.randn(16, 2),
|
||||
"group_list": torch.tensor([8, 16], dtype=torch.int64),
|
||||
"group_list_type": 0,
|
||||
}
|
||||
mock_combine_result_allgather = torch.randn(16, 2)
|
||||
def mock_prepare(hidden_states, router_logits, **kwargs):
|
||||
return hidden_states, router_logits
|
||||
|
||||
mock_token_dispatcher_with_allgather.token_dispatch.return_value = mock_dispatch_result_allgather
|
||||
mock_token_dispatcher_with_allgather.token_combine.return_value = mock_combine_result_allgather
|
||||
mock_moe_comm_method.prepare.side_effect = mock_prepare
|
||||
|
||||
mock_dispatch_result_all2allv = {
|
||||
"hidden_states": torch.randn(16, 2),
|
||||
"group_list": torch.tensor([4, 8, 12, 16], dtype=torch.int64),
|
||||
"group_list_type": 1,
|
||||
"dynamic_scale": None,
|
||||
}
|
||||
mock_combine_result_all2allv = torch.randn(16, 2)
|
||||
mock_token_dispatcher_with_all2allv.token_dispatch.return_value = mock_dispatch_result_all2allv
|
||||
mock_token_dispatcher_with_all2allv.token_combine.return_value = mock_combine_result_all2allv
|
||||
mock_fused_experts_result = torch.randn(16, 2)
|
||||
mock_moe_comm_method.fused_experts.return_value = mock_fused_experts_result
|
||||
|
||||
mock_dispatch_result_mc2 = {
|
||||
"hidden_states": torch.randn(16, 2),
|
||||
"group_list": torch.tensor([5, 10, 15, 16], dtype=torch.int64),
|
||||
"group_list_type": 1,
|
||||
"dynamic_scale": None,
|
||||
"assist_info_for_combine": torch.randn(16, 2),
|
||||
"ep_recv_counts": torch.tensor([4, 4, 4, 4], dtype=torch.int32),
|
||||
}
|
||||
mock_combine_result_mc2 = torch.randn(16, 2)
|
||||
mock_token_dispatcher_with_mc2.token_dispatch.return_value = mock_dispatch_result_mc2
|
||||
mock_token_dispatcher_with_mc2.token_combine.return_value = mock_combine_result_mc2
|
||||
def mock_finalize(hidden_states, **kwargs):
|
||||
return hidden_states
|
||||
|
||||
captured_dispatchers = {}
|
||||
|
||||
def capture_register(dispatcher_instance):
|
||||
key = dispatcher_instance.__class__.__name__
|
||||
captured_dispatchers[key] = dispatcher_instance
|
||||
if key == 'TokenDispatcherWithAllGather':
|
||||
captured_dispatchers[key] = mock_token_dispatcher_with_allgather
|
||||
elif key == 'TokenDispatcherWithAll2AllV':
|
||||
captured_dispatchers[key] = mock_token_dispatcher_with_all2allv
|
||||
elif key == 'TokenDispatcherWithMC2':
|
||||
captured_dispatchers[key] = mock_token_dispatcher_with_mc2
|
||||
|
||||
mock_register_token_dispatcher_patcher = patch(
|
||||
'vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher',
|
||||
side_effect=capture_register)
|
||||
|
||||
mock_get_token_dispatcher_patcher = patch(
|
||||
'vllm_ascend.ops.moe.token_dispatcher.get_token_dispatcher',
|
||||
side_effect=lambda name: captured_dispatchers.get(name))
|
||||
|
||||
default_mock_token_dispatcher = mock_token_dispatcher_with_allgather
|
||||
mock_moe_comm_method.finalize.side_effect = mock_finalize
|
||||
|
||||
mock_forward_context_obj = MagicMock(
|
||||
fused_moe_state=FusedMoEState.AllGather,
|
||||
token_dispatcher=default_mock_token_dispatcher,
|
||||
moe_comm_method=mock_moe_comm_method,
|
||||
moe_comm_method_name="mc2commimpl",
|
||||
max_tokens_across_dp=10,
|
||||
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
|
||||
mc2_mask=torch.zeros(16, dtype=torch.bool),
|
||||
@@ -131,14 +84,12 @@ def mock_dist_env(mocker: MockerFixture):
|
||||
with patch('torch.distributed.get_rank', return_value=0), \
|
||||
patch('torch.distributed.get_world_size', return_value=4), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('torch.distributed.all_gather'), \
|
||||
patch('torch.distributed.all_to_all_single'), \
|
||||
patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
|
||||
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
|
||||
return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_ascend_config',
|
||||
@@ -150,6 +101,8 @@ def mock_dist_env(mocker: MockerFixture):
|
||||
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
|
||||
return_value=MagicMock(
|
||||
parallel_config=MagicMock(tensor_parallel_size=2),
|
||||
@@ -157,22 +110,20 @@ def mock_dist_env(mocker: MockerFixture):
|
||||
model_config=MagicMock(max_model_len=2048)
|
||||
)), \
|
||||
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
|
||||
patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \
|
||||
patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
|
||||
return_value=mock_forward_context_obj):
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch('vllm_ascend.ops.moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
|
||||
return_value=None), \
|
||||
patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
|
||||
return_value=None), \
|
||||
patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
|
||||
return_value=None):
|
||||
|
||||
yield {
|
||||
'mock_forward_context_obj': mock_forward_context_obj,
|
||||
'mock_token_dispatcher_with_allgather':
|
||||
mock_token_dispatcher_with_allgather,
|
||||
'mock_token_dispatcher_with_all2allv':
|
||||
mock_token_dispatcher_with_all2allv,
|
||||
'mock_token_dispatcher_with_mc2': mock_token_dispatcher_with_mc2,
|
||||
'mock_moe_comm_method': mock_moe_comm_method,
|
||||
}
|
||||
|
||||
mock_register_token_dispatcher_patcher.stop()
|
||||
mock_get_token_dispatcher_patcher.stop()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_moe_env(mocker: MockerFixture):
|
||||
@@ -338,9 +289,7 @@ class TestAscendFusedMoe:
|
||||
moe.moe_parallel_config.ep_size = 1
|
||||
|
||||
moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
|
||||
forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
|
||||
dtype=torch.bool),
|
||||
padded_num_tokens=num_tokens)
|
||||
forward_context = mock_dist_env['mock_forward_context_obj']
|
||||
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
|
||||
return_value=forward_context):
|
||||
output = moe.forward(inputs,
|
||||
@@ -394,25 +343,10 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
[[256, 4], [128, 1], [128, 1], [128, 4]])
|
||||
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
|
||||
mock_moe_env, others_param):
|
||||
|
||||
global_num_experts, ep_size = others_param
|
||||
is_prefill = False
|
||||
is_deepseek_v3_r1 = global_num_experts == 256
|
||||
|
||||
if ep_size == 1:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_allgather']
|
||||
elif ep_size < 16:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_all2allv']
|
||||
else:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_mc2']
|
||||
|
||||
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
|
||||
ep_size, is_prefill, is_deepseek_v3_r1),
|
||||
with_quant=False,
|
||||
token_dispatcher=selected_token_dispatcher)
|
||||
forward_context = mock_dist_env['mock_forward_context_obj']
|
||||
|
||||
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
|
||||
return_value=forward_context):
|
||||
@@ -438,35 +372,22 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
global_num_experts=global_num_experts,
|
||||
is_prefill=is_prefill)
|
||||
|
||||
expected_shape = (16, 2)
|
||||
mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
|
||||
mock_moe_comm_method.fused_experts.assert_called_once()
|
||||
|
||||
expected_shape = (16, 2)
|
||||
assert result.shape == expected_shape
|
||||
|
||||
@pytest.mark.parametrize("others_param", [16, 1, 4])
|
||||
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
|
||||
mock_moe_env, others_param):
|
||||
|
||||
ep_size = others_param
|
||||
is_prefill = False
|
||||
|
||||
if ep_size == 1:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_allgather']
|
||||
elif ep_size < 16:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_all2allv']
|
||||
else:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_mc2']
|
||||
|
||||
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
|
||||
ep_size, is_prefill, True),
|
||||
with_quant=False,
|
||||
token_dispatcher=selected_token_dispatcher)
|
||||
forward_context = mock_dist_env['mock_forward_context_obj']
|
||||
|
||||
with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
|
||||
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):
|
||||
|
||||
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
|
||||
moe_method.ep_size = ep_size
|
||||
x = torch.randn(8, 2, 2)
|
||||
@@ -493,8 +414,10 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
expert_map=expert_map,
|
||||
is_prefill=is_prefill)
|
||||
|
||||
expected_shape = (16, 2)
|
||||
mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
|
||||
mock_moe_comm_method.fused_experts.assert_called_once()
|
||||
|
||||
expected_shape = (16, 2)
|
||||
assert result.shape == expected_shape
|
||||
|
||||
|
||||
@@ -574,7 +497,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
mock_get_forward_context):
|
||||
|
||||
mock_forward_context = MagicMock()
|
||||
mock_forward_context.fused_moe_state = FusedMoEState.MC2
|
||||
mock_forward_context.moe_comm_method_name = "mc2commimpl"
|
||||
mock_get_forward_context.return_value = mock_forward_context
|
||||
|
||||
mock_is_310p.return_value = False
|
||||
@@ -618,8 +541,6 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
with_quant=True)
|
||||
|
||||
mock_get_forward_context.assert_called()
|
||||
self.assertEqual(mock_forward_context.fused_moe_state,
|
||||
FusedMoEState.MC2)
|
||||
|
||||
mock_npu_dynamic_quant.assert_called()
|
||||
|
||||
|
||||
@@ -23,8 +23,7 @@ from tests.ut.base import TestBase
|
||||
|
||||
from vllm_ascend.ops.moe.token_dispatcher import ( # isort: skip
|
||||
AscendSocVersion, TokenDispatcherWithAll2AllV,
|
||||
TokenDispatcherWithAllGather, TokenDispatcherWithMC2, _Dispatchers,
|
||||
_register_token_dispatcher, get_token_dispatcher, setup_token_dispatchers)
|
||||
TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
|
||||
|
||||
|
||||
class TestTokenDispatcherWithMC2(TestBase):
|
||||
@@ -521,99 +520,3 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
self.assertIsNotNone(result["hidden_states"])
|
||||
self.assertIsNotNone(result["group_list"])
|
||||
self.assertEqual(result["group_list_type"], 1)
|
||||
|
||||
|
||||
class TestDispatcherRegistry(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
_Dispatchers.clear()
|
||||
|
||||
def tearDown(self):
|
||||
_Dispatchers.clear()
|
||||
|
||||
def test_register_and_get_token_dispatcher(self):
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_dispatcher.__class__.__name__ = "MockDispatcher"
|
||||
|
||||
_register_token_dispatcher(mock_dispatcher)
|
||||
|
||||
self.assertIn("MockDispatcher", _Dispatchers)
|
||||
self.assertIs(_Dispatchers["MockDispatcher"], mock_dispatcher)
|
||||
|
||||
retrieved_dispatcher = get_token_dispatcher("MockDispatcher")
|
||||
self.assertIs(retrieved_dispatcher, mock_dispatcher)
|
||||
|
||||
self.assertIsNone(get_token_dispatcher("NonExistentDispatcher"))
|
||||
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAllGather')
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
|
||||
def test_setup_token_dispatchers_ep_size_1_creates_allgather(
|
||||
self, mock_register, mock_allgather_class):
|
||||
kwargs = {"top_k": 2, "num_experts": 8}
|
||||
mock_instance = MagicMock()
|
||||
mock_allgather_class.return_value = mock_instance
|
||||
|
||||
self.assertNotIn("TokenDispatcherWithAllGather", _Dispatchers)
|
||||
|
||||
setup_token_dispatchers(ep_size=1, **kwargs)
|
||||
|
||||
mock_allgather_class.assert_called_once_with(**kwargs)
|
||||
mock_register.assert_called_once_with(mock_instance)
|
||||
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
|
||||
def test_setup_token_dispatchers_ep_size_2_creates_all2allv(
|
||||
self, mock_register, mock_all2allv_class):
|
||||
kwargs = {"top_k": 2, "num_experts": 16, "num_local_experts": 2}
|
||||
mock_instance = MagicMock()
|
||||
mock_all2allv_class.return_value = mock_instance
|
||||
|
||||
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
|
||||
|
||||
setup_token_dispatchers(ep_size=2, **kwargs)
|
||||
|
||||
mock_all2allv_class.assert_called_once_with(**kwargs)
|
||||
mock_register.assert_called_once_with(mock_instance)
|
||||
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithMC2')
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
|
||||
def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2(
|
||||
self, mock_register, mock_mc2_class, mock_all2allv_class):
|
||||
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
|
||||
mock_all2allv_instance = MagicMock()
|
||||
mock_mc2_instance = MagicMock()
|
||||
mock_all2allv_class.return_value = mock_all2allv_instance
|
||||
mock_mc2_class.return_value = mock_mc2_instance
|
||||
|
||||
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
|
||||
self.assertNotIn("TokenDispatcherWithMC2", _Dispatchers)
|
||||
|
||||
setup_token_dispatchers(ep_size=16, **kwargs)
|
||||
|
||||
mock_all2allv_class.assert_called_once_with(**kwargs)
|
||||
mock_mc2_class.assert_called_once_with(**kwargs)
|
||||
self.assertEqual(mock_register.call_count, 2)
|
||||
mock_register.assert_any_call(mock_all2allv_instance)
|
||||
mock_register.assert_any_call(mock_mc2_instance)
|
||||
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithMC2')
|
||||
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
|
||||
def test_setup_token_dispatchers_ep_size_16_skips_if_exist(
|
||||
self, mock_register, mock_mc2_class, mock_all2allv_class):
|
||||
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
|
||||
mock_existing_all2allv = MagicMock()
|
||||
mock_existing_mc2 = MagicMock()
|
||||
_Dispatchers["TokenDispatcherWithAll2AllV"] = mock_existing_all2allv
|
||||
_Dispatchers["TokenDispatcherWithMC2"] = mock_existing_mc2
|
||||
|
||||
setup_token_dispatchers(ep_size=16, **kwargs)
|
||||
|
||||
mock_all2allv_class.assert_not_called()
|
||||
mock_mc2_class.assert_not_called()
|
||||
mock_register.assert_not_called()
|
||||
self.assertIs(_Dispatchers["TokenDispatcherWithAll2AllV"],
|
||||
mock_existing_all2allv)
|
||||
self.assertIs(_Dispatchers["TokenDispatcherWithMC2"],
|
||||
mock_existing_mc2)
|
||||
|
||||
@@ -21,37 +21,31 @@ from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
"soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, expected_method",
|
||||
"soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method",
|
||||
[
|
||||
# Case 1: Expert parallel is disabled, should always be 'allgather'
|
||||
(AscendSocVersion.A2, False, 8, 100, 256, "allgather"),
|
||||
(AscendSocVersion.A3, False, 16, 500, 256, "allgather"),
|
||||
(AscendSocVersion.A2, False, 8, 100, 256, None, "allgather"),
|
||||
(AscendSocVersion.A3, False, 16, 500, 256, None, "allgather"),
|
||||
|
||||
# Case 2: A2 SOC
|
||||
# 2.1: MC2 conditions met (tokens <= capacity, world_size >= 16)
|
||||
(AscendSocVersion.A2, True, 16, 100, 256, "mc2"),
|
||||
(AscendSocVersion.A2, True, 32, 256, 256, "mc2"),
|
||||
# 2.2: MC2 token capacity exceeded
|
||||
(AscendSocVersion.A2, True, 16, 257, 256, "allgather"),
|
||||
# 2.3: MC2 world size not met
|
||||
(AscendSocVersion.A2, True, 8, 100, 256, "allgather"),
|
||||
(AscendSocVersion.A2, True, 15, 100, 256, "allgather"),
|
||||
# Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2
|
||||
(AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", "alltoall"),
|
||||
(AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", "alltoall"),
|
||||
(AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", "mc2"), # meets mc2 condition
|
||||
|
||||
# Case 3: A3 SOC
|
||||
# 3.1: MC2 condition met (tokens <= capacity)
|
||||
(AscendSocVersion.A3, True, 8, 100, 256, "mc2"),
|
||||
(AscendSocVersion.A3, True, 16, 256, 256, "mc2"),
|
||||
# 3.2: MC2 token capacity exceeded
|
||||
(AscendSocVersion.A3, True, 8, 257, 256, "alltoall"),
|
||||
(AscendSocVersion.A3, True, 16, 500, 256, "alltoall"),
|
||||
# Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather
|
||||
(AscendSocVersion.A2, True, 8, 100, 256, None, "allgather"),
|
||||
(AscendSocVersion.A2, True, 16, 257, 256, None, "allgather"),
|
||||
|
||||
# Case 4: A3 SOC
|
||||
(AscendSocVersion.A3, True, 8, 100, 256, None, "mc2"),
|
||||
(AscendSocVersion.A3, True, 8, 257, 256, None, "alltoall"),
|
||||
])
|
||||
# yapf: enable
|
||||
def test_select_moe_comm_method(soc_version, enable_expert_parallel,
|
||||
world_size, num_tokens, mc2_tokens_capacity,
|
||||
expected_method):
|
||||
quant_type, expected_method):
|
||||
"""
|
||||
Tests the _select_moe_comm_method with various configurations.
|
||||
Tests the _select_moe_comm_method with various configurations including quant_type.
|
||||
"""
|
||||
# Mock the NPUModelRunner instance and its dependencies
|
||||
mock_runner = MagicMock(spec=NPUModelRunner)
|
||||
@@ -60,15 +54,24 @@ def test_select_moe_comm_method(soc_version, enable_expert_parallel,
|
||||
mock_runner.parallel_config.world_size_across_dp = world_size
|
||||
mock_runner.mc2_tokens_capacity = mc2_tokens_capacity
|
||||
|
||||
# Add vllm_config.model_config.hf_config mock with moe_quantize
|
||||
mock_hf_config = MagicMock()
|
||||
mock_hf_config.moe_quantize = quant_type
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.hf_config = mock_hf_config
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config = mock_model_config
|
||||
mock_runner.vllm_config = mock_vllm_config
|
||||
|
||||
# Patch the helper functions
|
||||
with patch('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version',
|
||||
return_value=soc_version), \
|
||||
patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank',
|
||||
return_value=True):
|
||||
|
||||
# Call the method under test
|
||||
# Bind the real method to the mock object
|
||||
method = NPUModelRunner._select_moe_comm_method(
|
||||
mock_runner, num_tokens)
|
||||
mock_runner, num_tokens, False)
|
||||
|
||||
# Assert the result
|
||||
assert method == expected_method
|
||||
@@ -83,6 +86,15 @@ def test_select_moe_comm_method_unsupported_soc():
|
||||
mock_runner.parallel_config.enable_expert_parallel = True
|
||||
mock_runner.mc2_tokens_capacity = 256
|
||||
|
||||
# Add vllm_config.model_config.hf_config mock with moe_quantize
|
||||
mock_hf_config = MagicMock()
|
||||
mock_hf_config.moe_quantize = None
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.hf_config = mock_hf_config
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config = mock_model_config
|
||||
mock_runner.vllm_config = mock_vllm_config
|
||||
|
||||
unsupported_soc = "UnsupportedSOC"
|
||||
|
||||
with patch('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version',
|
||||
@@ -91,4 +103,4 @@ def test_select_moe_comm_method_unsupported_soc():
|
||||
return_value=True), \
|
||||
pytest.raises(ValueError, match=f"Unsupported soc_version: {unsupported_soc}"):
|
||||
|
||||
NPUModelRunner._select_moe_comm_method(mock_runner, 100)
|
||||
NPUModelRunner._select_moe_comm_method(mock_runner, 100, False)
|
||||
|
||||
Reference in New Issue
Block a user