[Refactor] Adjustments to moe_comm_method selection process (#3001)
### What this PR does / why we need it?
Fix the issues raised in
https://github.com/vllm-project/vllm-ascend/pull/2791 and do some minor
refactoring.
1. Use Enum instead of string.
2. Avoid setting a new property on forward_context in
AscendFusedMoE.forward().
3. Enable TokenDispatcherWithMoge.
4. Remove redundant code.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Tested with Qwen3-30B-A3B, Qwen3-30B-A3B-W8A8, DeepSeek-V3-W4A8-Pruing, deepseek-mtp, and pangu-pro-moe-pruing in the following configurations:
1. EP enabled and disabled
2. Aclgraph and eager modes
- vLLM version: v0.10.2
- vLLM main: 9607d5eb44
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
@@ -17,56 +17,7 @@ from unittest.mock import patch
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE, fused_experts_moge
|
||||
|
||||
|
||||
class TestFusedExpertsMoGE(TestBase):
|
||||
|
||||
def test_fused_experts_moge(self):
|
||||
with patch('torch_npu.npu_grouped_matmul') as mock_grouped_matmul, \
|
||||
patch('torch_npu.npu_swiglu') as mock_swiglu, \
|
||||
patch('vllm_ascend.utils.is_310p') as mock_is_310p:
|
||||
|
||||
mock_is_310p.return_value = False
|
||||
|
||||
mock_grouped_matmul.side_effect = lambda x, weight, **kwargs: [
|
||||
torch.randn(x[0].shape[0], weight[0].shape[1])
|
||||
]
|
||||
|
||||
mock_swiglu.side_effect = lambda x: x
|
||||
|
||||
hidden_states = torch.randn(4, 128)
|
||||
w1 = torch.randn(4, 256, 128)
|
||||
w2 = torch.randn(4, 128, 128)
|
||||
topk_weights = torch.rand(4, 1)
|
||||
topk_ids = torch.tensor([[0], [1], [2], [3]], dtype=torch.long)
|
||||
top_k = 1
|
||||
global_num_experts = 4
|
||||
|
||||
moe_parallel_config = type(
|
||||
'MockConfig', (), {
|
||||
'ep_size': 1,
|
||||
'tp_size': 1,
|
||||
'dp_size': 1,
|
||||
'tp_rank': 0,
|
||||
'dp_rank': 0,
|
||||
'ep_rank': 0,
|
||||
'use_ep': True
|
||||
})()
|
||||
|
||||
output = fused_experts_moge(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
moe_parallel_config=moe_parallel_config,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=True,
|
||||
)
|
||||
|
||||
self.assertEqual(output.shape, (4, 128))
|
||||
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
|
||||
|
||||
|
||||
class TestLoadWeight(TestBase):
|
||||
|
||||
@@ -23,6 +23,7 @@ from pytest_mock import MockerFixture
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
|
||||
AscendUnquantizedFusedMoEMethod)
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
@@ -55,6 +56,26 @@ def mock_npu_format_cast(weight_data, format):
|
||||
return weight_data
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_vllm_config_mock(mocker: MockerFixture):
|
||||
mock_hf_config = MagicMock()
|
||||
mock_hf_config.model_type = "llama"
|
||||
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.hf_config = mock_hf_config
|
||||
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config = mock_model_config
|
||||
mock_vllm_config.parallel_config = MagicMock(tensor_parallel_size=2)
|
||||
mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
|
||||
mock_vllm_config.model_config.max_model_len = 2048
|
||||
|
||||
mocker.patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
|
||||
return_value=mock_vllm_config)
|
||||
mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
|
||||
return_value=mock_vllm_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dist_env(mocker: MockerFixture):
|
||||
mock_moe_comm_method = MagicMock()
|
||||
@@ -74,7 +95,7 @@ def mock_dist_env(mocker: MockerFixture):
|
||||
|
||||
mock_forward_context_obj = MagicMock(
|
||||
moe_comm_method=mock_moe_comm_method,
|
||||
moe_comm_method_name="mc2commimpl",
|
||||
moe_comm_type=MoECommType.MC2,
|
||||
max_tokens_across_dp=10,
|
||||
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
|
||||
mc2_mask=torch.zeros(16, dtype=torch.bool),
|
||||
@@ -104,12 +125,6 @@ def mock_dist_env(mocker: MockerFixture):
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
|
||||
return_value=MagicMock(
|
||||
parallel_config=MagicMock(tensor_parallel_size=2),
|
||||
scheduler_config=MagicMock(max_num_seqs=4),
|
||||
model_config=MagicMock(max_model_len=2048)
|
||||
)), \
|
||||
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
|
||||
patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
@@ -501,7 +516,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
mock_get_forward_context):
|
||||
|
||||
mock_forward_context = MagicMock()
|
||||
mock_forward_context.moe_comm_method_name = "mc2commimpl"
|
||||
mock_forward_context.moe_comm_type = MoECommType.MC2
|
||||
mock_get_forward_context.return_value = mock_forward_context
|
||||
|
||||
mock_is_310p.return_value = False
|
||||
|
||||
@@ -24,6 +24,7 @@ class TestMoECommMethod(TestBase):
|
||||
self.moe_config.dp_group = MagicMock()
|
||||
self.moe_config.num_global_redundant_experts = 0
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
|
||||
@@ -31,7 +32,11 @@ class TestMoECommMethod(TestBase):
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
|
||||
def test_all_gather_comm_impl(self, mock_token_dispatcher,
|
||||
mock_prepare_finalize,
|
||||
mock_get_forward_context):
|
||||
mock_get_forward_context,
|
||||
mock_get_current_vllm_config):
|
||||
# Mock vLLM config
|
||||
mock_get_current_vllm_config.return_value = MagicMock()
|
||||
|
||||
# Mock forward context
|
||||
mock_context = MagicMock()
|
||||
mock_context.moe_comm_method = "all_gather"
|
||||
@@ -64,13 +69,18 @@ class TestMoECommMethod(TestBase):
|
||||
comm_impl.finalize(h_out, reduce_results=True)
|
||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithMC2"
|
||||
)
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithMC2")
|
||||
def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
|
||||
mock_get_forward_context):
|
||||
mock_get_forward_context,
|
||||
mock_get_current_vllm_config):
|
||||
# Mock vLLM config
|
||||
mock_get_current_vllm_config.return_value = MagicMock()
|
||||
|
||||
# Mock forward context
|
||||
mock_context = MagicMock()
|
||||
mock_context.moe_comm_method = "mc2"
|
||||
@@ -104,6 +114,7 @@ class TestMoECommMethod(TestBase):
|
||||
comm_impl.finalize(h_out, reduce_results=True)
|
||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAll2All"
|
||||
@@ -111,7 +122,11 @@ class TestMoECommMethod(TestBase):
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAll2AllV")
|
||||
def test_alltoall_comm_impl(self, mock_token_dispatcher,
|
||||
mock_prepare_finalize,
|
||||
mock_get_forward_context):
|
||||
mock_get_forward_context,
|
||||
mock_get_current_vllm_config):
|
||||
# Mock vLLM config
|
||||
mock_get_current_vllm_config.return_value = MagicMock()
|
||||
|
||||
# Mock forward context
|
||||
mock_context = MagicMock()
|
||||
mock_context.moe_comm_method = "alltoall"
|
||||
@@ -140,6 +155,7 @@ class TestMoECommMethod(TestBase):
|
||||
mock_pf_instance.prepare.assert_called_once_with(
|
||||
hidden_states, router_logits, False, False, False, None)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
|
||||
@@ -148,7 +164,11 @@ class TestMoECommMethod(TestBase):
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.unified_apply_mlp")
|
||||
def test_fused_experts_method(self, mock_unified_apply_mlp,
|
||||
mock_token_dispatcher, mock_prepare_finalize,
|
||||
mock_get_forward_context):
|
||||
mock_get_forward_context,
|
||||
mock_get_current_vllm_config):
|
||||
# Mock vLLM config
|
||||
mock_get_current_vllm_config.return_value = MagicMock()
|
||||
|
||||
# Mock forward context
|
||||
mock_context = MagicMock()
|
||||
mock_context.moe_comm_method = "all_gather"
|
||||
|
||||
@@ -48,18 +48,27 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
|
||||
output_size = 56
|
||||
group_size = 2
|
||||
|
||||
@patch('vllm_ascend.quantization.w4a8_dynamic.get_ascend_config')
|
||||
@patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
|
||||
@patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
|
||||
@patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
|
||||
@patch('torch.distributed.get_rank', return_value=0)
|
||||
def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ep_group,
|
||||
get_current_vllm_config):
|
||||
get_current_vllm_config, mock_get_ascend_config):
|
||||
# Mock ascend config
|
||||
mock_ascend_config = Mock()
|
||||
mock_ascend_config.dynamic_eplb = False
|
||||
mock_get_ascend_config.return_value = mock_ascend_config
|
||||
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.quant_config = Mock(quant_description={
|
||||
"group_size": self.group_size,
|
||||
"version": "0.0.0"
|
||||
})
|
||||
mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True)
|
||||
mock_vllm_config.scheduler_config = Mock(max_num_batched_tokens=2048,
|
||||
max_model_len=2048,
|
||||
enable_chunked_prefill=False)
|
||||
get_current_vllm_config.return_value = mock_vllm_config
|
||||
self.quant_method = AscendW4A8DynamicFusedMoEMethod()
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.utils import AscendSocVersion
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
@@ -24,21 +25,21 @@ from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
"soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method",
|
||||
[
|
||||
# Case 1: Expert parallel is disabled, should always be 'allgather'
|
||||
(AscendSocVersion.A2, False, 8, 100, 256, None, "allgather"),
|
||||
(AscendSocVersion.A3, False, 16, 500, 256, None, "allgather"),
|
||||
(AscendSocVersion.A2, False, 8, 100, 256, None, MoECommType.ALLGATHER),
|
||||
(AscendSocVersion.A3, False, 16, 500, 256, None, MoECommType.ALLGATHER),
|
||||
|
||||
# Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2
|
||||
(AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", "alltoall"),
|
||||
(AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", "alltoall"),
|
||||
(AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", "mc2"), # meets mc2 condition
|
||||
(AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", MoECommType.ALLTOALL),
|
||||
(AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", MoECommType.ALLTOALL),
|
||||
(AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", MoECommType.MC2), # meets mc2 condition
|
||||
|
||||
# Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather
|
||||
(AscendSocVersion.A2, True, 8, 100, 256, None, "allgather"),
|
||||
(AscendSocVersion.A2, True, 16, 257, 256, None, "allgather"),
|
||||
(AscendSocVersion.A2, True, 8, 100, 256, None, MoECommType.ALLGATHER),
|
||||
(AscendSocVersion.A2, True, 16, 257, 256, None, MoECommType.ALLGATHER),
|
||||
|
||||
# Case 4: A3 SOC
|
||||
(AscendSocVersion.A3, True, 8, 100, 256, None, "mc2"),
|
||||
(AscendSocVersion.A3, True, 8, 257, 256, None, "alltoall"),
|
||||
(AscendSocVersion.A3, True, 8, 100, 256, None, MoECommType.MC2),
|
||||
(AscendSocVersion.A3, True, 8, 257, 256, None, MoECommType.ALLTOALL),
|
||||
])
|
||||
# yapf: enable
|
||||
def test_select_moe_comm_method(soc_version, enable_expert_parallel,
|
||||
|
||||
Reference in New Issue
Block a user