[Fix][MoE] Refine MoE communication strategy (#2734)

### What this PR does / why we need it?
Refactors the Mixture-of-Experts (MoE) communication method selection
logic. The choice between all-gather, all-to-all, and mc2 is now
determined by expert parallel configuration, SoC version (A2/A3), and
token count for better performance.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
Added unit tests covering `_select_moe_comm_method` for every SOC version / expert-parallel / token-count combination, plus the unsupported-SOC error path.


- vLLM version: v0.10.1.1
- vLLM main:
eafa8dcde6

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
yiz-liu
2025-09-05 09:04:04 +08:00
committed by GitHub
parent 4c90fa79ca
commit 83eb40a51c
3 changed files with 123 additions and 9 deletions

View File

@@ -0,0 +1,94 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from unittest.mock import MagicMock, patch
import pytest
from vllm_ascend.utils import AscendSocVersion
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
# yapf: disable
@pytest.mark.parametrize(
    "soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, expected_method",
    [
        # Case 1: Expert parallel is disabled, should always be 'allgather'
        (AscendSocVersion.A2, False, 8, 100, 256, "allgather"),
        (AscendSocVersion.A3, False, 16, 500, 256, "allgather"),

        # Case 2: A2 SOC
        # 2.1: MC2 conditions met (tokens <= capacity, world_size >= 16)
        (AscendSocVersion.A2, True, 16, 100, 256, "mc2"),
        (AscendSocVersion.A2, True, 32, 256, 256, "mc2"),
        # 2.2: MC2 token capacity exceeded
        (AscendSocVersion.A2, True, 16, 257, 256, "allgather"),
        # 2.3: MC2 world size not met
        (AscendSocVersion.A2, True, 8, 100, 256, "allgather"),
        (AscendSocVersion.A2, True, 15, 100, 256, "allgather"),

        # Case 3: A3 SOC
        # 3.1: MC2 condition met (tokens <= capacity)
        (AscendSocVersion.A3, True, 8, 100, 256, "mc2"),
        (AscendSocVersion.A3, True, 16, 256, 256, "mc2"),
        # 3.2: MC2 token capacity exceeded
        (AscendSocVersion.A3, True, 8, 257, 256, "alltoall"),
        (AscendSocVersion.A3, True, 16, 500, 256, "alltoall"),
    ])
# yapf: enable
def test_select_moe_comm_method(soc_version, enable_expert_parallel,
                                world_size, num_tokens, mc2_tokens_capacity,
                                expected_method):
    """Exercise _select_moe_comm_method across SOC / EP / token configs.

    A stub runner carries only the attributes the method reads; the SOC
    version and rank helpers are patched so each parametrized scenario is
    fully deterministic.
    """
    # Stub out the runner: spec keeps attribute access honest while
    # avoiding NPUModelRunner's real (device-touching) __init__.
    runner = MagicMock(spec=NPUModelRunner)
    runner.parallel_config = MagicMock()
    runner.parallel_config.enable_expert_parallel = enable_expert_parallel
    runner.parallel_config.world_size = world_size
    runner.mc2_tokens_capacity = mc2_tokens_capacity

    # Pin the SOC version and short-circuit the first-rank logging guard,
    # then invoke the unbound method against the stub instance.
    with patch('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version',
               return_value=soc_version), \
         patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank',
               return_value=True):
        selected = NPUModelRunner._select_moe_comm_method(runner, num_tokens)

    assert selected == expected_method
def test_select_moe_comm_method_unsupported_soc():
    """An unrecognized SOC version must surface as a ValueError."""
    # Minimal stub: expert parallel enabled so selection reaches the
    # SOC-version dispatch instead of returning 'allgather' early.
    runner = MagicMock(spec=NPUModelRunner)
    runner.parallel_config = MagicMock()
    runner.parallel_config.enable_expert_parallel = True
    runner.mc2_tokens_capacity = 256

    bogus_soc = "UnsupportedSOC"

    with patch('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version',
               return_value=bogus_soc), \
         patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank',
               return_value=True), \
         pytest.raises(ValueError,
                       match=f"Unsupported soc_version: {bogus_soc}"):
        NPUModelRunner._select_moe_comm_method(runner, 100)

View File

@@ -482,11 +482,6 @@ class AscendFusedMoE(FusedMoE):
forward_context = get_forward_context()
moe_comm_method_name = forward_context.moe_comm_method_name
# TODO: Can we refactor this logic to model_runner?
# TODO: Adjusted logic to differentiate between A2 and A3, we check ep_size here since mc2 only support ep_size >= 16 on A3 now
if self.moe_config.ep_size < 16:
moe_comm_method_name = "allgathercommimpl"
forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
hidden_states, router_logits = forward_context.moe_comm_method.prepare(

View File

@@ -1434,14 +1434,39 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)
def _select_moe_comm_method(self, num_tokens: int) -> str:
"""1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
are designed for expert parallelism.
2. If expert parallel is enabled, we need to consider the soc version and the
number of tokens. This is based on the observation that all-gather is more
efficient than all-to-all when running on A2.
a. For A2, we choose from MC2 and all-gather.
b. For A3, we choose from MC2 and all-to-all.
In both cases, we use MC2 when the number of tokens does not exceed
its capacity threshold.
Args:
num_tokens (int): The number of tokens in the current batch.
Raises:
ValueError: If the soc version is unsupported.
Returns:
str: The selected MoE communication method, either "allgather", "mc2", or "alltoall".
"""
soc_version = get_ascend_soc_version()
if num_tokens <= self.mc2_tokens_capacity:
moe_comm_method = "mc2"
elif soc_version in {AscendSocVersion.A2}:
if not self.parallel_config.enable_expert_parallel:
moe_comm_method = "allgather"
elif soc_version in {AscendSocVersion.A2}:
if num_tokens <= self.mc2_tokens_capacity and self.parallel_config.world_size >= 16:
moe_comm_method = "mc2"
else:
moe_comm_method = "allgather"
elif soc_version in {AscendSocVersion.A3}:
moe_comm_method = "alltoall"
moe_comm_method = "mc2" if num_tokens <= self.mc2_tokens_capacity else "alltoall"
else:
raise ValueError(f"Unsupported soc_version: {soc_version}")