From 83eb40a51cb30c654c80a9203b3d9d7bd0351e4e Mon Sep 17 00:00:00 2001 From: yiz-liu <136800916+yiz-liu@users.noreply.github.com> Date: Fri, 5 Sep 2025 09:04:04 +0800 Subject: [PATCH] [Fix][MoE] Refine MoE communication strategy (#2734) ### What this PR does / why we need it? Refactors the Mixture-of-Experts (MoE) communication method selection logic. The choice between all-gather, all-to-all, and mc2 is now determined by expert parallel configuration, SoC version (A2/A3), and token count for better performance. ### Does this PR introduce _any_ user-facing change? None. ### How was this patch tested? Added. - vLLM version: v0.10.1.1 - vLLM main: https://github.com/vllm-project/vllm/commit/eafa8dcde63d625350ed618db4dd1cbcbaae77a1 --------- Signed-off-by: Yizhou Liu --- tests/ut/worker/test_model_runner_v1.py | 94 +++++++++++++++++++++++++ vllm_ascend/ops/common_fused_moe.py | 5 -- vllm_ascend/worker/model_runner_v1.py | 33 +++++++-- 3 files changed, 123 insertions(+), 9 deletions(-) create mode 100644 tests/ut/worker/test_model_runner_v1.py diff --git a/tests/ut/worker/test_model_runner_v1.py b/tests/ut/worker/test_model_runner_v1.py new file mode 100644 index 0000000..eb83d30 --- /dev/null +++ b/tests/ut/worker/test_model_runner_v1.py @@ -0,0 +1,94 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+ +from unittest.mock import MagicMock, patch + +import pytest + +from vllm_ascend.utils import AscendSocVersion +from vllm_ascend.worker.model_runner_v1 import NPUModelRunner + + +# yapf: disable +@pytest.mark.parametrize( + "soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, expected_method", + [ + # Case 1: Expert parallel is disabled, should always be 'allgather' + (AscendSocVersion.A2, False, 8, 100, 256, "allgather"), + (AscendSocVersion.A3, False, 16, 500, 256, "allgather"), + + # Case 2: A2 SOC + # 2.1: MC2 conditions met (tokens <= capacity, world_size >= 16) + (AscendSocVersion.A2, True, 16, 100, 256, "mc2"), + (AscendSocVersion.A2, True, 32, 256, 256, "mc2"), + # 2.2: MC2 token capacity exceeded + (AscendSocVersion.A2, True, 16, 257, 256, "allgather"), + # 2.3: MC2 world size not met + (AscendSocVersion.A2, True, 8, 100, 256, "allgather"), + (AscendSocVersion.A2, True, 15, 100, 256, "allgather"), + + # Case 3: A3 SOC + # 3.1: MC2 condition met (tokens <= capacity) + (AscendSocVersion.A3, True, 8, 100, 256, "mc2"), + (AscendSocVersion.A3, True, 16, 256, 256, "mc2"), + # 3.2: MC2 token capacity exceeded + (AscendSocVersion.A3, True, 8, 257, 256, "alltoall"), + (AscendSocVersion.A3, True, 16, 500, 256, "alltoall"), + + ]) +# yapf: enable +def test_select_moe_comm_method(soc_version, enable_expert_parallel, + world_size, num_tokens, mc2_tokens_capacity, + expected_method): + """ + Tests the _select_moe_comm_method with various configurations. 
+ """ + # Mock the NPUModelRunner instance and its dependencies + mock_runner = MagicMock(spec=NPUModelRunner) + mock_runner.parallel_config = MagicMock() + mock_runner.parallel_config.enable_expert_parallel = enable_expert_parallel + mock_runner.parallel_config.world_size = world_size + mock_runner.mc2_tokens_capacity = mc2_tokens_capacity + + # Patch the helper functions + with patch('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version', + return_value=soc_version), \ + patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank', + return_value=True): + + # Call the method under test + method = NPUModelRunner._select_moe_comm_method( + mock_runner, num_tokens) + + # Assert the result + assert method == expected_method + + +def test_select_moe_comm_method_unsupported_soc(): + """ + Tests that _select_moe_comm_method raises ValueError for an unsupported SOC. + """ + mock_runner = MagicMock(spec=NPUModelRunner) + mock_runner.parallel_config = MagicMock() + mock_runner.parallel_config.enable_expert_parallel = True + mock_runner.mc2_tokens_capacity = 256 + + unsupported_soc = "UnsupportedSOC" + + with patch('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version', + return_value=unsupported_soc), \ + patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank', + return_value=True), \ + pytest.raises(ValueError, match=f"Unsupported soc_version: {unsupported_soc}"): + + NPUModelRunner._select_moe_comm_method(mock_runner, 100) diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index 607991c..5cb2d6f 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -482,11 +482,6 @@ class AscendFusedMoE(FusedMoE): forward_context = get_forward_context() moe_comm_method_name = forward_context.moe_comm_method_name - # TODO: Can we refactor this logic to model_runner? 
- # TODO: Adjusted logic to differentiate between A2 and A3, we check ep_size here since mc2 only support ep_size >= 16 on A3 now - if self.moe_config.ep_size < 16: - moe_comm_method_name = "allgathercommimpl" - forward_context.moe_comm_method = getattr(self, moe_comm_method_name) hidden_states, router_logits = forward_context.moe_comm_method.prepare( diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 3068d36..9e8b58e 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1434,14 +1434,39 @@ class NPUModelRunner(LoRAModelRunnerMixin): ) def _select_moe_comm_method(self, num_tokens: int) -> str: + """1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all + are designed for expert parallelism. + 2. If expert parallel is enabled, we need to consider the soc version and the + number of tokens. This is based on the observation that all-gather is more + efficient than all-to-all when running on A2. + + a. For A2, we choose from MC2 and all-gather. + + b. For A3, we choose from MC2 and all-to-all. + + In both cases, we use MC2 when the number of tokens does not exceed + its capacity threshold. + + Args: + num_tokens (int): The number of tokens in the current batch. + + Raises: + ValueError: If the soc version is unsupported. + + Returns: + str: The selected MoE communication method, either "allgather", "mc2", or "alltoall". 
+ """ soc_version = get_ascend_soc_version() - if num_tokens <= self.mc2_tokens_capacity: - moe_comm_method = "mc2" - elif soc_version in {AscendSocVersion.A2}: + if not self.parallel_config.enable_expert_parallel: moe_comm_method = "allgather" + elif soc_version in {AscendSocVersion.A2}: + if num_tokens <= self.mc2_tokens_capacity and self.parallel_config.world_size >= 16: + moe_comm_method = "mc2" + else: + moe_comm_method = "allgather" elif soc_version in {AscendSocVersion.A3}: - moe_comm_method = "alltoall" + moe_comm_method = "mc2" if num_tokens <= self.mc2_tokens_capacity else "alltoall" else: raise ValueError(f"Unsupported soc_version: {soc_version}")