[Main] [Refactor] Enable MoECommMethod in Eager Mode (#2791)
### What this PR does / why we need it?
1. Replace prepare/finalize operation in fused_moe.py by
moe_comm_method.prepare()/finalize()
2. Replace unified_fused_experts by moe_comm_method.fused_experts() in
fused_moe.py/w8a8_dynamic.py/w4a8_dynamic.py
3. Add calling _select_moe_comm_method in spec-decode proposers.
4. Currently, w4a8_dynamic does not support gatherep, use all2allv
instead.
5. Remove redundant code.
### Does this PR introduce _any_ user-facing change?
AllgatherEP switch is disabled in aclgraph/eager mode, just follow the
rules in modelrunner_v1._select_moe_comm_method()
### How was this patch tested?
e2e & ut
- vLLM version: v0.10.2
- vLLM main:
7f6f2c1182
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
@@ -21,37 +21,31 @@ from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
"soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, expected_method",
|
||||
"soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method",
|
||||
[
|
||||
# Case 1: Expert parallel is disabled, should always be 'allgather'
|
||||
(AscendSocVersion.A2, False, 8, 100, 256, "allgather"),
|
||||
(AscendSocVersion.A3, False, 16, 500, 256, "allgather"),
|
||||
(AscendSocVersion.A2, False, 8, 100, 256, None, "allgather"),
|
||||
(AscendSocVersion.A3, False, 16, 500, 256, None, "allgather"),
|
||||
|
||||
# Case 2: A2 SOC
|
||||
# 2.1: MC2 conditions met (tokens <= capacity, world_size >= 16)
|
||||
(AscendSocVersion.A2, True, 16, 100, 256, "mc2"),
|
||||
(AscendSocVersion.A2, True, 32, 256, 256, "mc2"),
|
||||
# 2.2: MC2 token capacity exceeded
|
||||
(AscendSocVersion.A2, True, 16, 257, 256, "allgather"),
|
||||
# 2.3: MC2 world size not met
|
||||
(AscendSocVersion.A2, True, 8, 100, 256, "allgather"),
|
||||
(AscendSocVersion.A2, True, 15, 100, 256, "allgather"),
|
||||
# Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2
|
||||
(AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", "alltoall"),
|
||||
(AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", "alltoall"),
|
||||
(AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", "mc2"), # meets mc2 condition
|
||||
|
||||
# Case 3: A3 SOC
|
||||
# 3.1: MC2 condition met (tokens <= capacity)
|
||||
(AscendSocVersion.A3, True, 8, 100, 256, "mc2"),
|
||||
(AscendSocVersion.A3, True, 16, 256, 256, "mc2"),
|
||||
# 3.2: MC2 token capacity exceeded
|
||||
(AscendSocVersion.A3, True, 8, 257, 256, "alltoall"),
|
||||
(AscendSocVersion.A3, True, 16, 500, 256, "alltoall"),
|
||||
# Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather
|
||||
(AscendSocVersion.A2, True, 8, 100, 256, None, "allgather"),
|
||||
(AscendSocVersion.A2, True, 16, 257, 256, None, "allgather"),
|
||||
|
||||
# Case 4: A3 SOC
|
||||
(AscendSocVersion.A3, True, 8, 100, 256, None, "mc2"),
|
||||
(AscendSocVersion.A3, True, 8, 257, 256, None, "alltoall"),
|
||||
])
|
||||
# yapf: enable
|
||||
def test_select_moe_comm_method(soc_version, enable_expert_parallel,
|
||||
world_size, num_tokens, mc2_tokens_capacity,
|
||||
expected_method):
|
||||
quant_type, expected_method):
|
||||
"""
|
||||
Tests the _select_moe_comm_method with various configurations.
|
||||
Tests the _select_moe_comm_method with various configurations including quant_type.
|
||||
"""
|
||||
# Mock the NPUModelRunner instance and its dependencies
|
||||
mock_runner = MagicMock(spec=NPUModelRunner)
|
||||
@@ -60,15 +54,24 @@ def test_select_moe_comm_method(soc_version, enable_expert_parallel,
|
||||
mock_runner.parallel_config.world_size_across_dp = world_size
|
||||
mock_runner.mc2_tokens_capacity = mc2_tokens_capacity
|
||||
|
||||
# Add vllm_config.model_config.hf_config mock with moe_quantize
|
||||
mock_hf_config = MagicMock()
|
||||
mock_hf_config.moe_quantize = quant_type
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.hf_config = mock_hf_config
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config = mock_model_config
|
||||
mock_runner.vllm_config = mock_vllm_config
|
||||
|
||||
# Patch the helper functions
|
||||
with patch('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version',
|
||||
return_value=soc_version), \
|
||||
patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank',
|
||||
return_value=True):
|
||||
|
||||
# Call the method under test
|
||||
# Bind the real method to the mock object
|
||||
method = NPUModelRunner._select_moe_comm_method(
|
||||
mock_runner, num_tokens)
|
||||
mock_runner, num_tokens, False)
|
||||
|
||||
# Assert the result
|
||||
assert method == expected_method
|
||||
@@ -83,6 +86,15 @@ def test_select_moe_comm_method_unsupported_soc():
|
||||
mock_runner.parallel_config.enable_expert_parallel = True
|
||||
mock_runner.mc2_tokens_capacity = 256
|
||||
|
||||
# Add vllm_config.model_config.hf_config mock with moe_quantize
|
||||
mock_hf_config = MagicMock()
|
||||
mock_hf_config.moe_quantize = None
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.hf_config = mock_hf_config
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config = mock_model_config
|
||||
mock_runner.vllm_config = mock_vllm_config
|
||||
|
||||
unsupported_soc = "UnsupportedSOC"
|
||||
|
||||
with patch('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version',
|
||||
@@ -91,4 +103,4 @@ def test_select_moe_comm_method_unsupported_soc():
|
||||
return_value=True), \
|
||||
pytest.raises(ValueError, match=f"Unsupported soc_version: {unsupported_soc}"):
|
||||
|
||||
NPUModelRunner._select_moe_comm_method(mock_runner, 100)
|
||||
NPUModelRunner._select_moe_comm_method(mock_runner, 100, False)
|
||||
|
||||
Reference in New Issue
Block a user