[Main] [Refactor] Enable MoECommMethod in Eager Mode (#2791)

### What this PR does / why we need it?
1. Replace prepare/finalize operation in fused_moe.py by
moe_comm_method.prepare()/finalize()
2. Replace unified_fused_experts by moe_comm_method.fused_experts() in
fused_moe.py/w8a8_dynamic.py/w4a8_dynamic.py
3. Add calling _select_moe_comm_method in spec-decode proposers.
4. Currently, w4a8_dynamic does not support gatherep, use all2allv
instead.
5. Remove redundant code.
### Does this PR introduce _any_ user-facing change?
AllgatherEP switch is disabled in aclgraph/eager mode, just follow the
rules in modelrunner_v1._select_moe_comm_method()
### How was this patch tested?
e2e & ut


- vLLM version: v0.10.2
- vLLM main:
7f6f2c1182

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
weichen
2025-09-16 11:06:00 +08:00
committed by GitHub
parent 0aba644633
commit 18ca7861f6
18 changed files with 523 additions and 596 deletions

View File

@@ -23,8 +23,7 @@ from vllm.distributed.parallel_state import GroupCoordinator
from vllm_ascend.models.deepseek_v2 import (
CustomDeepseekV2MergedReplicatedLinear, CustomDeepseekV2MLAAttention,
CustomDeepseekV2MLP, CustomDeepseekV2MoE,
CustomDeepseekV2RowParallelLinear,
CustomDeepseekV2MLP, CustomDeepseekV2RowParallelLinear,
CustomDeepseekV2RowParallelLinearReplaceAllreduce,
CustomDeepseekV2SiluAndMul, LogitsProcessor, ParallelLMHead)
@@ -213,22 +212,6 @@ def test_custom_deepseek_v2_mlp(mock_distributed, base_config):
quant_config=None)
def test_custom_deepseek_v2_moe(mock_distributed, base_config,
mock_forward_context):
base_config.n_shared_experts = 1
moe = CustomDeepseekV2MoE(config=base_config,
quant_config=None,
prefix="mlp")
assert moe.top_k == 2
x = torch.randn(2, 4, 128)
attn_metadata = Mock(num_prefills=1)
with patch("vllm_ascend.ops.fused_moe.AscendFusedMoE.__call__",
return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
output = moe(x, attn_metadata)
assert output.shape == (2, 4, 128)
@patch("torch_npu.npu_rms_norm")
def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
base_config):

View File

@@ -1,22 +0,0 @@
import os
import unittest
from unittest import mock
from vllm_ascend.ascend_forward_context import get_dispatcher_name
class TestGetDispatcherName(unittest.TestCase):
def test_get_dispatcher_name(self):
result = get_dispatcher_name(1, False)
assert result == "TokenDispatcherWithAllGather"
result = get_dispatcher_name(4, False)
assert result == "TokenDispatcherWithAll2AllV"
result = get_dispatcher_name(16, True)
assert result == "TokenDispatcherWithAll2AllV"
result = get_dispatcher_name(16, False)
assert result == "TokenDispatcherWithMC2"
with mock.patch.dict(os.environ,
{"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP": "1"}):
result = get_dispatcher_name(16, False)
assert result == "TokenDispatcherWithAllGather"

View File

@@ -6,7 +6,8 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
FusedMoEPrepareAndFinalizeWithAll2All,
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2)
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
FusedMoEPrepareAndFinalizeWithNaiveMulticast)
class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
@@ -216,3 +217,68 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
mock_tp_all_reduce.return_value = result
result_with_tp = layer.finalize(h_out, reduce_results=True)
self.assertEqual(result_with_tp.shape[0], 3)
@patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group")
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce"
)
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
)
def test_naive_multicast_prepare_finalize(self, mock_get_forward_context,
mock_tp_all_reduce,
mock_get_dp_group):
# Mock forward context with DP metadata
mock_context = MagicMock()
mock_context.dp_metadata.cu_tokens_across_dp_cpu = torch.tensor(
[2, 5, 7])
mock_get_forward_context.return_value = mock_context
# Setup DP group mock
mock_dp_group = MagicMock()
mock_dp_group.broadcast = MagicMock()
mock_dp_group.all_reduce = MagicMock()
mock_get_dp_group.return_value = mock_dp_group
# Mock all_reduce to just return input (simulate sum)
def mock_all_reduce(tensor):
return tensor * 2
mock_dp_group.all_reduce.side_effect = mock_all_reduce
# Setup config
self.moe_config.dp_size = 3
self.moe_config.dp_rank = 1
self.moe_config.tp_size = 1
self.moe_config.ep_size = 1
layer = FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config)
# Local inputs
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)
# Mock gate for router logits recomputation
mock_gate = MagicMock()
mock_gate.return_value = (torch.randn(7, 2), None)
# Run prepare
h_out, r_out, _ = layer.prepare(hidden_states,
router_logits,
rm_router_logits=False,
gate=mock_gate)
# Should be global tensor: [7, 8] and [7, 2]
self.assertEqual(h_out.shape, (7, 8))
self.assertEqual(r_out.shape, (7, 2))
# Run finalize
result = layer.finalize(h_out, reduce_results=False)
# Should slice back to local: [3, 8]
self.assertEqual(result.shape, (3, 8))
# Test with reduce_results=True and TP/EP > 1
mock_tp_all_reduce.return_value = result
result_with_tp = layer.finalize(h_out, reduce_results=True)
self.assertEqual(result_with_tp.shape, (3, 8))

View File

@@ -22,10 +22,7 @@ import torch_npu
from pytest_mock import MockerFixture
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
import vllm_ascend.ops.moe.token_dispatcher as token_dispatcher_module
from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import (FusedMoEState,
_get_fused_moe_state)
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
AscendUnquantizedFusedMoEMethod)
from vllm_ascend.ops.moe.experts_selector import select_experts
@@ -60,68 +57,24 @@ def mock_npu_format_cast(weight_data, format):
@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
mock_setup_token_dispatchers = MagicMock()
mock_token_dispatcher_with_allgather = MagicMock()
mock_token_dispatcher_with_all2allv = MagicMock()
mock_token_dispatcher_with_mc2 = MagicMock()
mock_moe_comm_method = MagicMock()
mock_dispatch_result_allgather = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([8, 16], dtype=torch.int64),
"group_list_type": 0,
}
mock_combine_result_allgather = torch.randn(16, 2)
def mock_prepare(hidden_states, router_logits, **kwargs):
return hidden_states, router_logits
mock_token_dispatcher_with_allgather.token_dispatch.return_value = mock_dispatch_result_allgather
mock_token_dispatcher_with_allgather.token_combine.return_value = mock_combine_result_allgather
mock_moe_comm_method.prepare.side_effect = mock_prepare
mock_dispatch_result_all2allv = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([4, 8, 12, 16], dtype=torch.int64),
"group_list_type": 1,
"dynamic_scale": None,
}
mock_combine_result_all2allv = torch.randn(16, 2)
mock_token_dispatcher_with_all2allv.token_dispatch.return_value = mock_dispatch_result_all2allv
mock_token_dispatcher_with_all2allv.token_combine.return_value = mock_combine_result_all2allv
mock_fused_experts_result = torch.randn(16, 2)
mock_moe_comm_method.fused_experts.return_value = mock_fused_experts_result
mock_dispatch_result_mc2 = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([5, 10, 15, 16], dtype=torch.int64),
"group_list_type": 1,
"dynamic_scale": None,
"assist_info_for_combine": torch.randn(16, 2),
"ep_recv_counts": torch.tensor([4, 4, 4, 4], dtype=torch.int32),
}
mock_combine_result_mc2 = torch.randn(16, 2)
mock_token_dispatcher_with_mc2.token_dispatch.return_value = mock_dispatch_result_mc2
mock_token_dispatcher_with_mc2.token_combine.return_value = mock_combine_result_mc2
def mock_finalize(hidden_states, **kwargs):
return hidden_states
captured_dispatchers = {}
def capture_register(dispatcher_instance):
key = dispatcher_instance.__class__.__name__
captured_dispatchers[key] = dispatcher_instance
if key == 'TokenDispatcherWithAllGather':
captured_dispatchers[key] = mock_token_dispatcher_with_allgather
elif key == 'TokenDispatcherWithAll2AllV':
captured_dispatchers[key] = mock_token_dispatcher_with_all2allv
elif key == 'TokenDispatcherWithMC2':
captured_dispatchers[key] = mock_token_dispatcher_with_mc2
mock_register_token_dispatcher_patcher = patch(
'vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher',
side_effect=capture_register)
mock_get_token_dispatcher_patcher = patch(
'vllm_ascend.ops.moe.token_dispatcher.get_token_dispatcher',
side_effect=lambda name: captured_dispatchers.get(name))
default_mock_token_dispatcher = mock_token_dispatcher_with_allgather
mock_moe_comm_method.finalize.side_effect = mock_finalize
mock_forward_context_obj = MagicMock(
fused_moe_state=FusedMoEState.AllGather,
token_dispatcher=default_mock_token_dispatcher,
moe_comm_method=mock_moe_comm_method,
moe_comm_method_name="mc2commimpl",
max_tokens_across_dp=10,
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
mc2_mask=torch.zeros(16, dtype=torch.bool),
@@ -131,14 +84,12 @@ def mock_dist_env(mocker: MockerFixture):
with patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \
patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('torch.distributed.all_gather'), \
patch('torch.distributed.all_to_all_single'), \
patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_ascend_config',
@@ -150,6 +101,8 @@ def mock_dist_env(mocker: MockerFixture):
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
patch('vllm_ascend.ops.fused_moe.get_forward_context',
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
return_value=MagicMock(
parallel_config=MagicMock(tensor_parallel_size=2),
@@ -157,22 +110,20 @@ def mock_dist_env(mocker: MockerFixture):
model_config=MagicMock(max_model_len=2048)
)), \
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \
patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
return_value=mock_forward_context_obj):
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
return_value=None), \
patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
return_value=None), \
patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
return_value=None):
yield {
'mock_forward_context_obj': mock_forward_context_obj,
'mock_token_dispatcher_with_allgather':
mock_token_dispatcher_with_allgather,
'mock_token_dispatcher_with_all2allv':
mock_token_dispatcher_with_all2allv,
'mock_token_dispatcher_with_mc2': mock_token_dispatcher_with_mc2,
'mock_moe_comm_method': mock_moe_comm_method,
}
mock_register_token_dispatcher_patcher.stop()
mock_get_token_dispatcher_patcher.stop()
@pytest.fixture
def mock_moe_env(mocker: MockerFixture):
@@ -338,9 +289,7 @@ class TestAscendFusedMoe:
moe.moe_parallel_config.ep_size = 1
moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
dtype=torch.bool),
padded_num_tokens=num_tokens)
forward_context = mock_dist_env['mock_forward_context_obj']
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
return_value=forward_context):
output = moe.forward(inputs,
@@ -394,25 +343,10 @@ class TestAscendUnquantizedFusedMoEMethod:
[[256, 4], [128, 1], [128, 1], [128, 4]])
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
global_num_experts, ep_size = others_param
is_prefill = False
is_deepseek_v3_r1 = global_num_experts == 256
if ep_size == 1:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_allgather']
elif ep_size < 16:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_all2allv']
else:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_mc2']
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
ep_size, is_prefill, is_deepseek_v3_r1),
with_quant=False,
token_dispatcher=selected_token_dispatcher)
forward_context = mock_dist_env['mock_forward_context_obj']
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
return_value=forward_context):
@@ -438,35 +372,22 @@ class TestAscendUnquantizedFusedMoEMethod:
global_num_experts=global_num_experts,
is_prefill=is_prefill)
expected_shape = (16, 2)
mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
mock_moe_comm_method.fused_experts.assert_called_once()
expected_shape = (16, 2)
assert result.shape == expected_shape
@pytest.mark.parametrize("others_param", [16, 1, 4])
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
ep_size = others_param
is_prefill = False
if ep_size == 1:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_allgather']
elif ep_size < 16:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_all2allv']
else:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_mc2']
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
ep_size, is_prefill, True),
with_quant=False,
token_dispatcher=selected_token_dispatcher)
forward_context = mock_dist_env['mock_forward_context_obj']
with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
@@ -493,8 +414,10 @@ class TestAscendUnquantizedFusedMoEMethod:
expert_map=expert_map,
is_prefill=is_prefill)
expected_shape = (16, 2)
mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
mock_moe_comm_method.fused_experts.assert_called_once()
expected_shape = (16, 2)
assert result.shape == expected_shape
@@ -574,7 +497,7 @@ class TestUnifiedApplyMLP(TestBase):
mock_get_forward_context):
mock_forward_context = MagicMock()
mock_forward_context.fused_moe_state = FusedMoEState.MC2
mock_forward_context.moe_comm_method_name = "mc2commimpl"
mock_get_forward_context.return_value = mock_forward_context
mock_is_310p.return_value = False
@@ -618,8 +541,6 @@ class TestUnifiedApplyMLP(TestBase):
with_quant=True)
mock_get_forward_context.assert_called()
self.assertEqual(mock_forward_context.fused_moe_state,
FusedMoEState.MC2)
mock_npu_dynamic_quant.assert_called()

View File

@@ -23,8 +23,7 @@ from tests.ut.base import TestBase
from vllm_ascend.ops.moe.token_dispatcher import ( # isort: skip
AscendSocVersion, TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather, TokenDispatcherWithMC2, _Dispatchers,
_register_token_dispatcher, get_token_dispatcher, setup_token_dispatchers)
TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
class TestTokenDispatcherWithMC2(TestBase):
@@ -521,99 +520,3 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertEqual(result["group_list_type"], 1)
class TestDispatcherRegistry(TestBase):
def setUp(self):
_Dispatchers.clear()
def tearDown(self):
_Dispatchers.clear()
def test_register_and_get_token_dispatcher(self):
mock_dispatcher = MagicMock()
mock_dispatcher.__class__.__name__ = "MockDispatcher"
_register_token_dispatcher(mock_dispatcher)
self.assertIn("MockDispatcher", _Dispatchers)
self.assertIs(_Dispatchers["MockDispatcher"], mock_dispatcher)
retrieved_dispatcher = get_token_dispatcher("MockDispatcher")
self.assertIs(retrieved_dispatcher, mock_dispatcher)
self.assertIsNone(get_token_dispatcher("NonExistentDispatcher"))
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAllGather')
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
def test_setup_token_dispatchers_ep_size_1_creates_allgather(
self, mock_register, mock_allgather_class):
kwargs = {"top_k": 2, "num_experts": 8}
mock_instance = MagicMock()
mock_allgather_class.return_value = mock_instance
self.assertNotIn("TokenDispatcherWithAllGather", _Dispatchers)
setup_token_dispatchers(ep_size=1, **kwargs)
mock_allgather_class.assert_called_once_with(**kwargs)
mock_register.assert_called_once_with(mock_instance)
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
def test_setup_token_dispatchers_ep_size_2_creates_all2allv(
self, mock_register, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 16, "num_local_experts": 2}
mock_instance = MagicMock()
mock_all2allv_class.return_value = mock_instance
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
setup_token_dispatchers(ep_size=2, **kwargs)
mock_all2allv_class.assert_called_once_with(**kwargs)
mock_register.assert_called_once_with(mock_instance)
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithMC2')
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2(
self, mock_register, mock_mc2_class, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
mock_all2allv_instance = MagicMock()
mock_mc2_instance = MagicMock()
mock_all2allv_class.return_value = mock_all2allv_instance
mock_mc2_class.return_value = mock_mc2_instance
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
self.assertNotIn("TokenDispatcherWithMC2", _Dispatchers)
setup_token_dispatchers(ep_size=16, **kwargs)
mock_all2allv_class.assert_called_once_with(**kwargs)
mock_mc2_class.assert_called_once_with(**kwargs)
self.assertEqual(mock_register.call_count, 2)
mock_register.assert_any_call(mock_all2allv_instance)
mock_register.assert_any_call(mock_mc2_instance)
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithMC2')
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
def test_setup_token_dispatchers_ep_size_16_skips_if_exist(
self, mock_register, mock_mc2_class, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
mock_existing_all2allv = MagicMock()
mock_existing_mc2 = MagicMock()
_Dispatchers["TokenDispatcherWithAll2AllV"] = mock_existing_all2allv
_Dispatchers["TokenDispatcherWithMC2"] = mock_existing_mc2
setup_token_dispatchers(ep_size=16, **kwargs)
mock_all2allv_class.assert_not_called()
mock_mc2_class.assert_not_called()
mock_register.assert_not_called()
self.assertIs(_Dispatchers["TokenDispatcherWithAll2AllV"],
mock_existing_all2allv)
self.assertIs(_Dispatchers["TokenDispatcherWithMC2"],
mock_existing_mc2)

View File

@@ -21,37 +21,31 @@ from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
# yapf: disable
@pytest.mark.parametrize(
"soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, expected_method",
"soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method",
[
# Case 1: Expert parallel is disabled, should always be 'allgather'
(AscendSocVersion.A2, False, 8, 100, 256, "allgather"),
(AscendSocVersion.A3, False, 16, 500, 256, "allgather"),
(AscendSocVersion.A2, False, 8, 100, 256, None, "allgather"),
(AscendSocVersion.A3, False, 16, 500, 256, None, "allgather"),
# Case 2: A2 SOC
# 2.1: MC2 conditions met (tokens <= capacity, world_size >= 16)
(AscendSocVersion.A2, True, 16, 100, 256, "mc2"),
(AscendSocVersion.A2, True, 32, 256, 256, "mc2"),
# 2.2: MC2 token capacity exceeded
(AscendSocVersion.A2, True, 16, 257, 256, "allgather"),
# 2.3: MC2 world size not met
(AscendSocVersion.A2, True, 8, 100, 256, "allgather"),
(AscendSocVersion.A2, True, 15, 100, 256, "allgather"),
# Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2
(AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", "alltoall"),
(AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", "alltoall"),
(AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", "mc2"), # meets mc2 condition
# Case 3: A3 SOC
# 3.1: MC2 condition met (tokens <= capacity)
(AscendSocVersion.A3, True, 8, 100, 256, "mc2"),
(AscendSocVersion.A3, True, 16, 256, 256, "mc2"),
# 3.2: MC2 token capacity exceeded
(AscendSocVersion.A3, True, 8, 257, 256, "alltoall"),
(AscendSocVersion.A3, True, 16, 500, 256, "alltoall"),
# Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather
(AscendSocVersion.A2, True, 8, 100, 256, None, "allgather"),
(AscendSocVersion.A2, True, 16, 257, 256, None, "allgather"),
# Case 4: A3 SOC
(AscendSocVersion.A3, True, 8, 100, 256, None, "mc2"),
(AscendSocVersion.A3, True, 8, 257, 256, None, "alltoall"),
])
# yapf: enable
def test_select_moe_comm_method(soc_version, enable_expert_parallel,
world_size, num_tokens, mc2_tokens_capacity,
expected_method):
quant_type, expected_method):
"""
Tests the _select_moe_comm_method with various configurations.
Tests the _select_moe_comm_method with various configurations including quant_type.
"""
# Mock the NPUModelRunner instance and its dependencies
mock_runner = MagicMock(spec=NPUModelRunner)
@@ -60,15 +54,24 @@ def test_select_moe_comm_method(soc_version, enable_expert_parallel,
mock_runner.parallel_config.world_size_across_dp = world_size
mock_runner.mc2_tokens_capacity = mc2_tokens_capacity
# Add vllm_config.model_config.hf_config mock with moe_quantize
mock_hf_config = MagicMock()
mock_hf_config.moe_quantize = quant_type
mock_model_config = MagicMock()
mock_model_config.hf_config = mock_hf_config
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_runner.vllm_config = mock_vllm_config
# Patch the helper functions
with patch('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version',
return_value=soc_version), \
patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank',
return_value=True):
# Call the method under test
# Bind the real method to the mock object
method = NPUModelRunner._select_moe_comm_method(
mock_runner, num_tokens)
mock_runner, num_tokens, False)
# Assert the result
assert method == expected_method
@@ -83,6 +86,15 @@ def test_select_moe_comm_method_unsupported_soc():
mock_runner.parallel_config.enable_expert_parallel = True
mock_runner.mc2_tokens_capacity = 256
# Add vllm_config.model_config.hf_config mock with moe_quantize
mock_hf_config = MagicMock()
mock_hf_config.moe_quantize = None
mock_model_config = MagicMock()
mock_model_config.hf_config = mock_hf_config
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_runner.vllm_config = mock_vllm_config
unsupported_soc = "UnsupportedSOC"
with patch('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version',
@@ -91,4 +103,4 @@ def test_select_moe_comm_method_unsupported_soc():
return_value=True), \
pytest.raises(ValueError, match=f"Unsupported soc_version: {unsupported_soc}"):
NPUModelRunner._select_moe_comm_method(mock_runner, 100)
NPUModelRunner._select_moe_comm_method(mock_runner, 100, False)

View File

@@ -42,17 +42,6 @@ def _get_fused_moe_state(ep_size: int, with_prefill: bool,
return FusedMoEState.MC2
def get_dispatcher_name(ep_size: int, with_prefill: bool) -> str:
if ep_size == 1:
return "TokenDispatcherWithAllGather"
elif envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1:
return "TokenDispatcherWithAllGather"
elif ep_size < 16 or with_prefill:
return "TokenDispatcherWithAll2AllV"
else:
return "TokenDispatcherWithMC2"
@contextmanager
def set_ascend_forward_context(
attn_metadata: Any,
@@ -97,11 +86,6 @@ def set_ascend_forward_context(
forward_context.fused_moe_state = fused_moe_state
forward_context.in_profile_run = in_profile_run
from vllm_ascend.ops.moe.token_dispatcher import get_token_dispatcher
dispatcher_name = get_dispatcher_name(ep_size, with_prefill)
dispatcher = get_token_dispatcher(dispatcher_name)
forward_context.token_dispatcher = dispatcher
# NOTE: This cannot be set using set_forward_context
# due to multiple warmups before actual capturing
forward_context.capturing = False

View File

@@ -32,8 +32,8 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
AlltoAllCommImpl, MC2CommImpl)
from vllm_ascend.ops.moe.token_dispatcher import setup_token_dispatchers
AlltoAllCommImpl, MC2CommImpl,
NaiveMulticastCommImpl)
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
@@ -280,17 +280,17 @@ class AscendFusedMoE(FusedMoE):
num_redundant_experts,
has_bias,
)
setup_token_dispatchers(self.moe_config.ep_size,
top_k=self.top_k,
num_experts=self.global_num_experts,
num_local_experts=self.local_num_experts)
self.hidden_size = hidden_size
self.moe_config.tp_group = get_tp_group()
self.moe_config.dp_group = get_dp_group()
self.moe_config.ep_group = get_ep_group()
self.moe_config.mc2_group = get_mc2_group()
for method in {AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl}:
for method in {
AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl,
NaiveMulticastCommImpl
}:
setattr(
self, method.__name__.lower(),
method(moe_config=self.moe_config)) # type: ignore[abstract]

View File

@@ -19,13 +19,10 @@ import os
from typing import Any, Callable, Optional
import torch
import torch.distributed as dist
import torch_npu
from torch import nn
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
get_tensor_model_parallel_world_size)
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
get_tp_group)
from vllm.forward_context import get_forward_context
@@ -39,72 +36,18 @@ from vllm.model_executor.layers.quantization.base_config import \
QuantizationConfig
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
AlltoAllCommImpl, MC2CommImpl,
NaiveMulticastCommImpl)
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, dispose_tensor,
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
get_all_reduce_merge_state,
get_rm_router_logits_state, is_310p)
def unified_fused_experts_eager(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
w1_scale: Optional[torch.Tensor] = None,
w1_scale_bias: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w2_scale_bias: Optional[torch.Tensor] = None,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[Any] = None,
shared_dequant_scale: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
fusion_mlp: bool = False):
token_dispatcher = get_forward_context().token_dispatcher
results = token_dispatcher.token_dispatch(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
expert_map=expert_map,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts,
shared_gate_up=shared_gate_up,
shared_dequant_scale=shared_dequant_scale,
mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=with_quant)
expert_output = unified_apply_mlp(
hidden_states=results["hidden_states"],
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=results["group_list"],
dynamic_scale=results.get("dynamic_scale"),
group_list_type=results.get("group_list_type"),
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
topk_scales=results.get("topk_scales"),
with_quant=with_quant,
fusion=fusion_mlp)
final_hidden_states = token_dispatcher.token_combine(expert_output)
return final_hidden_states
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
def __init__(self, moe: FusedMoEConfig = None):
@@ -182,17 +125,18 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
if enable_force_load_balance and not self.use_aclgraph:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
return unified_fused_experts_eager(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
expert_map=expert_map,
shared_experts=shared_experts,
mc2_mask=kwargs.get(
"mc2_mask", None),
with_quant=False)
moe_comm_method = get_forward_context().moe_comm_method
return moe_comm_method.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
global_num_experts=global_num_experts,
expert_map=expert_map,
shared_experts=shared_experts,
need_trans=True)
class AscendFusedMoE(FusedMoE):
@@ -354,18 +298,20 @@ class AscendFusedMoE(FusedMoE):
# NOTE: self.tp_group is not expert_tp_group
self.tp_group = get_tp_group().device_group
self.quant_method.create_weights(layer=self, **moe_quant_params)
self.token_dispatcher = None
ep_size = (get_ep_group().world_size if
vllm_config.parallel_config.enable_expert_parallel else 1)
from vllm_ascend.ops.moe.token_dispatcher import \
setup_token_dispatchers
setup_token_dispatchers(
ep_size,
top_k=self.top_k,
num_experts=self.global_num_experts,
num_global_redundant_experts=self.global_redundant_expert_num,
num_local_experts=self.local_num_experts)
self.moe_config.tp_group = get_tp_group()
self.moe_config.dp_group = get_dp_group()
self.moe_config.ep_group = get_ep_group()
self.moe_config.mc2_group = get_mc2_group()
self.moe_config.num_global_redundant_experts = self.global_redundant_expert_num
for method in {
AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl,
NaiveMulticastCommImpl
}:
setattr(
self, method.__name__.lower(),
method(moe_config=self.moe_config)) # type: ignore[abstract]
def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor):
@@ -401,10 +347,7 @@ class AscendFusedMoE(FusedMoE):
else:
real_top_k = self.top_k
num_tokens, hidden_size = hidden_states.shape
forward_context = get_forward_context()
fused_moe_state = forward_context.fused_moe_state
mc2_mask = forward_context.mc2_mask
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
quantized_x_for_share, dynamic_scale_for_share = None, None
@@ -422,63 +365,16 @@ class AscendFusedMoE(FusedMoE):
mc2_mask = chunk_mc2_mask[tp_rank]
replace_allreduce = True
if (fused_moe_state not in [
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
FusedMoEState.NaiveMulticast
] and not replace_allreduce):
if fused_moe_state in {FusedMoEState.MC2}:
padding_size = forward_context.padded_num_tokens
else:
# TODO: Determine if we can remove the padding
padding_size = tp_size
if num_tokens < padding_size and not self.enable_shared_expert_dp:
hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, padding_size - num_tokens))
router_logits = nn.functional.pad(
router_logits, (0, 0, 0, padding_size - num_tokens))
if tp_size > 1:
tp_rank = get_tensor_model_parallel_rank()
if not self.enable_shared_expert_dp:
chunk_hidden_states = torch.tensor_split(hidden_states,
tp_size,
dim=0)
chunk_router_logits = torch.tensor_split(router_logits,
tp_size,
dim=0)
hidden_states = chunk_hidden_states[tp_rank]
router_logits = chunk_router_logits[tp_rank]
moe_comm_method_name = forward_context.moe_comm_method_name
forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
mc2_mask = chunk_mc2_mask[tp_rank]
if self.dp_size > 1:
if fused_moe_state == FusedMoEState.AllGather:
# NOTE: When in torchair graph, it has been padded in model_runner_v1
max_tokens_across_dp = forward_context.max_tokens_across_dp
if num_tokens < max_tokens_across_dp:
hidden_states = nn.functional.pad(
hidden_states,
(0, 0, 0, max_tokens_across_dp - num_tokens))
if not self.rm_router_logits:
router_logits = nn.functional.pad(
router_logits,
(0, 0, 0, max_tokens_across_dp - num_tokens))
hidden_states = get_dp_group().all_gather(hidden_states, 0)
if self.rm_router_logits:
router_logits, _ = gate(hidden_states)
else:
router_logits = get_dp_group().all_gather(router_logits, 0)
elif fused_moe_state == FusedMoEState.NaiveMulticast:
cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_dp_cpu
hidden_states = self.naive_multicast(hidden_states,
cu_tokens_across_dp_cpu)
if self.rm_router_logits:
router_logits, _ = gate(hidden_states)
else:
router_logits = self.naive_multicast(
router_logits, cu_tokens_across_dp_cpu)
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
hidden_states=hidden_states,
router_logits=router_logits,
enable_shared_expert_dp=self.enable_shared_expert_dp,
rm_router_logits=self.rm_router_logits,
replace_allreduce=replace_allreduce,
gate=gate)
# Matrix multiply.
e_hidden_states = self.quant_method.apply(
@@ -501,7 +397,6 @@ class AscendFusedMoE(FusedMoE):
global_redundant_expert_num=self.global_redundant_expert_num,
shared_experts=None,
mc2_mask=mc2_mask,
token_dispatcher=self.token_dispatcher,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
)
@@ -510,44 +405,9 @@ class AscendFusedMoE(FusedMoE):
if isinstance(e_hidden_states, tuple):
e_hidden_states, shared_hidden_states = e_hidden_states
if (fused_moe_state not in [
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
FusedMoEState.NaiveMulticast
] and not replace_allreduce and not self.enable_shared_expert_dp):
if tp_size > 1:
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
self.tp_group)
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
dispose_tensor(e_hidden_states)
else:
final_hidden_states = e_hidden_states
if num_tokens < padding_size:
final_hidden_states = final_hidden_states[:num_tokens]
elif self.dp_size > 1 and not self.enable_shared_expert_dp:
if fused_moe_state == FusedMoEState.NaiveMulticast:
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
self.dp_rank - 1]
end = cu_tokens_across_dp_cpu[self.dp_rank]
final_hidden_states = get_dp_group().all_reduce(
e_hidden_states)
final_hidden_states = final_hidden_states[start:end, :]
dispose_tensor(e_hidden_states)
elif fused_moe_state == FusedMoEState.AllGather:
final_hidden_states = get_dp_group().reduce_scatter(
e_hidden_states, 0)
final_hidden_states = final_hidden_states[:num_tokens]
dispose_tensor(e_hidden_states)
else:
final_hidden_states = e_hidden_states
else:
final_hidden_states = e_hidden_states
if tp_size > 1 and not self.all_reduce_merge and fused_moe_state in [
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
FusedMoEState.NaiveMulticast
]:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
final_hidden_states = forward_context.moe_comm_method.finalize(
hidden_states=e_hidden_states,
reduce_results=(not self.all_reduce_merge))
if shared_experts:
return final_hidden_states, shared_hidden_states

View File

@@ -28,6 +28,16 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
class FusedMoEPrepareAndFinalize(ABC):
"""
Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization
in distributed environments. Subclasses implement specific communication strategies
(e.g., AllGather, All2All, MC2, Naive Multicast) to handle tensor padding, slicing,
broadcasting, and reduction across TP/DP/EP groups.
Attributes:
moe_config (FusedMoEConfig): Configuration object containing TP/DP/EP group info,
sizes, ranks, and communication settings.
"""
def __init__(self, moe_config: FusedMoEConfig):
self.moe_config = moe_config
@@ -40,22 +50,65 @@ class FusedMoEPrepareAndFinalize(ABC):
rm_router_logits: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Prepare tensors before MoE computation. May involve:
- Padding to align communication boundaries
- Slicing across tensor-parallel ranks
- Broadcasting across data-parallel ranks
- Recomputing router logits if needed
Args:
hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size]
router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
enable_shared_expert_dp (bool): Skip DP communication for shared experts
rm_router_logits (bool): Discard input router_logits and recompute via gate
replace_allreduce (bool): Bypass default all-reduce behavior
gate (nn.Module, optional): Gate network to recompute router_logits if needed
Returns:
Tuple of:
- processed hidden_states (may be padded/sliced/broadcasted)
- processed router_logits (may be recomputed or broadcasted)
- optional communication mask (e.g., mc2_mask for sparse ops)
"""
raise NotImplementedError("Prepare not implemented.")
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
raise NotImplementedError("Combine function not implemented.")
"""
Finalize MoE output. May involve:
- Gathering sliced tensors across TP ranks
- Reducing or scattering across DP ranks
- Unpadding to original token count
- Applying all-reduce across TP/EP if requested
Args:
hidden_states (torch.Tensor): MoE layer output, possibly padded or sliced
reduce_results (bool): Whether to apply all-reduce across TP/EP groups
Returns:
torch.Tensor: Final output with shape [original_num_tokens, hidden_size]
"""
raise NotImplementedError("Finalize function not implemented.")
class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
"""
MoE communication strategy using MC2 (Memory-Centric Communication).
Designed for Ascend or environments requiring explicit padding and slicing control.
Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment.
"""
def __init__(self, moe_config: FusedMoEConfig):
super().__init__(moe_config)
self._restore_tp_across_dp()
def _restore_tp_across_dp(self):
# NOTE: Since vLLM flatten tp across dp, we need to restore the original
# tp_size and tp_rank.
"""
Restore original TP configuration.
vLLM flattens TP and DP into a single dimension; this method recovers
the true TP world size and rank for correct tensor slicing.
"""
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
@@ -66,9 +119,17 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
rm_router_logits: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""The target_pad_length is calculated in forward_context, here we pad the
hidden states and router logits. And if TP size > 1, we also need to split
the tensors accordingly.
"""
Preparation steps:
1. Fetch `mc2_mask` and target padding length from forward context.
2. Pad `hidden_states` and `router_logits` to target length if needed.
3. If TP > 1, split tensors along token dimension and select current TP rank's slice.
4. Split and return corresponding `mc2_mask`.
Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True.
Returns:
Tuple of (hidden_states, router_logits, mc2_mask), possibly sliced/padded.
"""
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
@@ -80,12 +141,14 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
target_pad_length = forward_context.padded_num_tokens
pad_size = target_pad_length - self.num_tokens
# Pad if necessary (unless shared expert DP is enabled)
if pad_size > 0 and not self.enable_shared_expert_dp:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
# Slice across TP ranks
if self.tp_size > 1:
if not self.enable_shared_expert_dp:
split_hidden_states = torch.tensor_split(hidden_states,
@@ -96,8 +159,9 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
dim=0)
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
self.split_hidden_states = split_hidden_states
self.split_hidden_states = split_hidden_states # Save for finalize
# Also slice mc2_mask
split_mc2_mask = torch.tensor_split(mc2_mask,
self.tp_size,
dim=0)
@@ -107,16 +171,22 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""If TP size > 1, all-gather the hidden states to get the final output.
Also, unpad the hidden states if needed.
"""
Finalization steps:
1. If TP > 1, all-gather slices from all TP ranks to reconstruct full tensor.
2. Unpad to original token count if padding was applied.
3. Return tensor with shape [original_num_tokens, hidden_size].
Skips communication and unpadding if `enable_shared_expert_dp` or `replace_allreduce` is True.
"""
if not (self.enable_shared_expert_dp or self.replace_allreduce):
if self.tp_size > 1:
# All-gather across TP group
dist.all_gather(list(self.split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group)
hidden_states = torch.cat(self.split_hidden_states, dim=0)
# Unpad if necessary
if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens]
@@ -124,14 +194,18 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
"""
MoE communication strategy using All-to-All style slicing.
Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing.
Will be used when num_tokens exceed mc2's limitation (512 tokens/rank).
"""
def __init__(self, moe_config: FusedMoEConfig):
super().__init__(moe_config)
self._restore_tp_across_dp()
def _restore_tp_across_dp(self):
# NOTE: Since vLLM flatten tp across dp, we need to restore the original
# tp_size and tp_rank.
"""Restore original TP configuration (same as MC2)."""
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
@@ -142,12 +216,23 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
rm_router_logits: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Preparation steps:
1. Pad hidden_states and router_logits to next multiple of TP size.
2. If TP > 1, split along token dim and select current TP rank's slice.
3. Save splits for later all-gather in finalize.
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
Returns:
Tuple of (hidden_states, router_logits, None) — no mask used in All2All.
"""
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
if not (self.replace_allreduce or self.enable_shared_expert_dp):
self.num_tokens, _ = hidden_states.shape
pad_size = self.tp_size - self.num_tokens
pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic)
if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states,
@@ -171,9 +256,13 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""If TP size > 1, all-gather the hidden states to get the final output.
"""
Finalization steps:
1. If TP > 1, all-gather slices to reconstruct full tensor.
2. Unpad to original token count.
3. Return [original_num_tokens, hidden_size] tensor.
Also, unpad the hidden states if needed.
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
"""
if not (self.enable_shared_expert_dp or self.replace_allreduce):
if self.tp_size > 1:
@@ -188,6 +277,11 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
"""
MoE communication strategy using All-Gather + Reduce-Scatter.
Designed for DP > 1: gather inputs across DP ranks before MoE, scatter outputs after.
Uses `max_tokens_across_dp` from forward_context for padding alignment.
"""
def prepare(self,
hidden_states: torch.Tensor,
@@ -196,8 +290,16 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
rm_router_logits: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""When DP size > 1, pad the hidden states and router logits for communication."""
self.rm_router_logits = rm_router_logits
"""
Preparation steps:
1. Fetch max token count across DP group from forward context.
2. Pad local tensors to that size.
3. All-gather across DP group to form global input tensor.
4. Optionally recompute router_logits using gate if `rm_router_logits=True`.
Returns:
Tuple of (global_hidden_states, global_router_logits, None)
"""
self.enable_shared_expert_dp = enable_shared_expert_dp
if self.moe_config.dp_size > 1:
@@ -209,14 +311,15 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
if not self.rm_router_logits:
if not rm_router_logits:
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
# All-gather across DP group
hidden_states = self.moe_config.dp_group.all_gather(
hidden_states, 0)
if self.rm_router_logits:
router_logits, _ = gate(hidden_states)
if rm_router_logits:
router_logits, _ = gate(hidden_states) # Recompute globally
else:
router_logits = self.moe_config.dp_group.all_gather(
router_logits, 0)
@@ -225,9 +328,14 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""When DP size > 1, reduce-scatter the hidden states to get the final output.
"""
Finalization steps:
1. If DP > 1 and not shared expert, reduce-scatter output across DP group.
2. Slice to original local token count.
3. If `reduce_results=True` and TP/EP > 1, apply tensor_model_parallel_all_reduce.
When TP size > 1, all-reduce the hidden states to get the final output.
Returns:
Tensor with shape [original_local_num_tokens, hidden_size]
"""
if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp:
hidden_states = get_dp_group().reduce_scatter(hidden_states, 0)
@@ -238,3 +346,101 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
return hidden_states
class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
"""
MoE communication strategy using Naive Multicast (point-to-point broadcast).
Will be used in prefill when using allgather in decode. Each DP rank broadcasts its slice to all others.
Uses `cu_tokens_across_dp_cpu` (cumulative tokens) to locate slice boundaries.
"""
def _naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor):
"""
Naive multicast implementation:
1. Create global buffer sized by total tokens across DP.
2. Current rank copies its slice into its designated buffer region.
3. Each rank broadcasts its slice to all others via P2P.
Args:
x (torch.Tensor): Local tensor [local_tokens, hidden_size]
cu_tokens_across_dp_cpu (torch.Tensor): Cumulative token counts per DP rank
Returns:
torch.Tensor: Global tensor [total_tokens, hidden_size]
"""
assert len(x.shape) == 2, "Input must be 2D [tokens, features]"
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
device=x.device,
dtype=x.dtype)
# Copy local slice into buffer
start = 0 if self.moe_config.dp_rank == 0 else cu_tokens_across_dp_cpu[
self.moe_config.dp_rank - 1]
end = cu_tokens_across_dp_cpu[self.moe_config.dp_rank]
buffer[start:end, :].copy_(x)
# Broadcast each slice to all ranks
for idx in range(self.moe_config.dp_size):
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
end = cu_tokens_across_dp_cpu[idx]
get_dp_group().broadcast(buffer[start:end, :], idx)
return buffer
def prepare(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
rm_router_logits: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Preparation steps:
1. Fetch cumulative token boundaries from forward context.
2. Multicast hidden_states and router_logits to form global tensors.
3. Optionally recompute router_logits globally if `rm_router_logits=True`.
Returns:
Tuple of (global_hidden_states, global_router_logits, None)
"""
self.enable_shared_expert_dp = enable_shared_expert_dp
if self.moe_config.dp_size > 1:
self.cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_dp_cpu
hidden_states = self._naive_multicast(hidden_states,
self.cu_tokens_across_dp_cpu)
if rm_router_logits:
router_logits, _ = gate(hidden_states)
else:
router_logits = self._naive_multicast(
router_logits, self.cu_tokens_across_dp_cpu)
return hidden_states, router_logits, None
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""
Finalization steps:
1. If DP > 1 and not shared expert:
- All-reduce across DP
- Slice to current rank's token range using cu_tokens_across_dp_cpu
2. If `reduce_results=True` and TP/EP > 1, apply tensor_model_parallel_all_reduce.
Returns:
Tensor with shape [local_num_tokens, hidden_size]
"""
if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp:
start = 0 if self.moe_config.dp_rank == 0 else self.cu_tokens_across_dp_cpu[
self.moe_config.dp_rank - 1]
end = self.cu_tokens_across_dp_cpu[self.moe_config.dp_rank]
hidden_states = get_dp_group().all_reduce(
hidden_states) # Sum across DP
hidden_states = hidden_states[start:end, :]
if reduce_results and (self.moe_config.tp_size > 1
or self.moe_config.ep_size > 1):
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
return hidden_states

View File

@@ -23,7 +23,8 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
FusedMoEPrepareAndFinalizeWithAll2All,
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2)
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
FusedMoEPrepareAndFinalizeWithNaiveMulticast)
from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.moe.token_dispatcher import (TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather,
@@ -82,8 +83,6 @@ class MoECommMethod(ABC):
is_torchair: bool = False,
# For Cube/Vector parallel
shared_experts: Optional[Any] = None,
shared_gate_up: Optional[Any] = None,
shared_dequant_scale: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
# For load balance
@@ -91,13 +90,6 @@ class MoECommMethod(ABC):
global_redundant_expert_num: int = 0,
need_trans: bool = False) -> torch.Tensor:
# Check constraints
assert hidden_states.shape[1] == w1.shape[1], (
f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[1]}")
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(
), "Hidden_states must be contiguous"
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
@@ -114,8 +106,8 @@ class MoECommMethod(ABC):
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts,
shared_gate_up=shared_gate_up,
shared_dequant_scale=shared_dequant_scale,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
mc2_mask=self.mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=use_int8_w8a8 or use_int4_w4a8)
@@ -135,12 +127,13 @@ class MoECommMethod(ABC):
w2_scale_bias=w2_scale_bias,
with_quant=use_int8_w8a8
or use_int4_w4a8,
fusion=use_int8_w8a8,
need_trans=need_trans)
hidden_states[:] = self.token_dispatcher.token_combine(
final_hidden_states = self.token_dispatcher.token_combine(
hidden_states=mlp_output)
return hidden_states
return final_hidden_states
@abstractmethod
def _get_token_dispatcher(self):
@@ -296,3 +289,32 @@ class AlltoAllCommImpl(MoECommMethod):
def _get_fused_moe_prepare_finalize(self):
return FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)
class NaiveMulticastCommImpl(MoECommMethod):
"""This implementation is the same as NativeAllGatherCommImpl,
but uses NPU-specific ops for better performance.
This implementation should be compatible with all scenarios, and
thus it is the default implementation for MoE communication methods.
It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing
and `torch_npu.npu_moe_token_unpermute` for post-processing
to handle the token-to-expert mapping and communication efficiently.
NOTE(Yizhou): TBH, it is really weird that we were supposed to use
`torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing`
or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute`
for pre-processing and post-processing, respectively.
But `npu_moe_finalize_routing` will lead to accuracy issues so we have to
use `torch_npu.npu_moe_token_unpermute` instead.
This is a workaround and should be removed after the issue is fixed.
"""
def _get_token_dispatcher(self):
return TokenDispatcherWithAllGather(
top_k=self.moe_config.experts_per_token,
num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts)
def _get_fused_moe_prepare_finalize(self):
return FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config)

View File

@@ -21,7 +21,6 @@ import torch_npu
from torch.nn.functional import pad
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.utils import dispose_tensor, is_310p
@@ -77,7 +76,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
bias1, bias2 = None, None
_output_dtype = w2_scale.dtype
is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2
is_mc2 = get_forward_context().moe_comm_method_name == "mc2commimpl"
if w1_scale_bias is None and is_mc2:
if w1_scale.dtype != torch.float32:
w1_scale = w1_scale.to(torch.float32)

View File

@@ -22,45 +22,17 @@
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from typing import Any, Optional
import torch
import torch_npu
from vllm.distributed.parallel_state import get_ep_group
import vllm_ascend.envs as envs_ascend
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.moe.comm_utils import (
async_all_to_all, gather_from_sequence_parallel_region)
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
_Dispatchers: Dict[str, Any] = {}
def _register_token_dispatcher(dispatcher: Any):
_Dispatchers[dispatcher.__class__.__name__] = dispatcher
def get_token_dispatcher(name: str):
return _Dispatchers.get(name)
def setup_token_dispatchers(ep_size: int, **kwargs):
existing_dispatchers = set(_Dispatchers.keys())
if ep_size == 1 and "TokenDispatcherWithAllGather" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs))
elif envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1 \
and "TokenDispatcherWithAllGather" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs))
elif ep_size < 16 and "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
elif ep_size >= 16:
if "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
if "TokenDispatcherWithMC2" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithMC2(**kwargs))
class MoETokenDispatcher(ABC):
@@ -93,9 +65,9 @@ class MoETokenDispatcher(ABC):
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
@@ -192,9 +164,9 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
@@ -218,6 +190,11 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
if self.with_quant:
if shared_experts is not None:
share_up_out, _ = shared_experts.gate_up_proj(
(quantized_x_for_share, dynamic_scale_for_share))
shared_gate_up, shared_dequant_scale = share_up_out[
0], share_up_out[1]
shared_act_out = shared_experts.act_fn(
(shared_gate_up, shared_dequant_scale))
self.shared_act, self.swiglu_out_scale = \
@@ -331,9 +308,9 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
@@ -418,9 +395,9 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
@@ -530,9 +507,9 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):

View File

@@ -24,9 +24,7 @@ from vllm.config import get_current_vllm_config
from vllm.distributed import get_ep_group
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe import unified_fused_experts_eager
from vllm_ascend.ops.moe.experts_selector import select_experts
@@ -275,14 +273,6 @@ class AscendW4A8DynamicFusedMoEMethod:
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
fused_moe_state = get_forward_context().fused_moe_state
shared_gate_up, shared_dequant_scale = None, None
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
share_up_out, _ = shared_experts.gate_up_proj(
(quantized_x_for_share, dynamic_scale_for_share))
shared_gate_up, shared_dequant_scale = share_up_out[
0], share_up_out[1]
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
@@ -291,7 +281,8 @@ class AscendW4A8DynamicFusedMoEMethod:
topk_weights = topk_weights.to(x.dtype)
return unified_fused_experts_eager(
moe_comm_method = get_forward_context().moe_comm_method
return moe_comm_method.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
@@ -302,14 +293,13 @@ class AscendW4A8DynamicFusedMoEMethod:
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
use_int4_w4a8=True,
expert_map=expert_map,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts,
shared_gate_up=shared_gate_up,
shared_dequant_scale=shared_dequant_scale,
mc2_mask=kwargs.get("mc2_mask", None),
with_quant=True)
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share)
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
group_num, k, n = weight.shape

View File

@@ -24,9 +24,7 @@ from vllm.distributed import get_ep_group
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe import unified_fused_experts_eager
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
@@ -233,14 +231,6 @@ class AscendW8A8DynamicFusedMoEMethod:
expert_map=expert_map,
)
fused_moe_state = get_forward_context().fused_moe_state
shared_gate_up, shared_dequant_scale = None, None
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
share_up_out, _ = shared_experts.gate_up_proj(
(quantized_x_for_share, dynamic_scale_for_share))
shared_gate_up, shared_dequant_scale = share_up_out[
0], share_up_out[1]
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
@@ -249,7 +239,8 @@ class AscendW8A8DynamicFusedMoEMethod:
topk_weights = topk_weights.to(x.dtype)
return unified_fused_experts_eager(
moe_comm_method = get_forward_context().moe_comm_method
return moe_comm_method.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale_fp32,
@@ -258,15 +249,13 @@ class AscendW8A8DynamicFusedMoEMethod:
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
use_int8_w8a8=True,
expert_map=expert_map,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts,
shared_gate_up=shared_gate_up,
shared_dequant_scale=shared_dequant_scale,
mc2_mask=kwargs.get("mc2_mask", None),
with_quant=True,
fusion_mlp=True)
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share)
def process_weights_after_loading(self, layer):
if self.transpose_weight:

View File

@@ -117,8 +117,11 @@ class EagleProposer(Proposer):
skip_attn: bool = False,
num_reqs: int = 0,
num_tokens_across_dp: Optional[torch.Tensor] = None):
moe_comm_method = self.runner._select_moe_comm_method(
num_tokens, with_prefill)
with set_ascend_forward_context(None,
self.vllm_config,
moe_comm_method=moe_comm_method,
num_tokens=num_tokens):
self.model(
input_ids=self.input_ids[:num_tokens],
@@ -447,12 +450,20 @@ class EagleProposer(Proposer):
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
num_input_tokens = num_tokens
with_prefill = attn_metadata.attn_state not in [
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]
moe_comm_method = self.runner._select_moe_comm_method(
num_input_tokens, with_prefill)
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions.to(device)
self.hidden_states[:num_tokens] = target_hidden_states
attn_metadata.block_tables = block_table.to(device)
with set_ascend_forward_context(attn_metadata,
self.vllm_config,
moe_comm_method=moe_comm_method,
num_tokens=num_input_tokens):
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
@@ -483,6 +494,10 @@ class EagleProposer(Proposer):
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
moe_comm_method = self.runner._select_moe_comm_method(
input_batch_size, False)
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
@@ -553,6 +568,7 @@ class EagleProposer(Proposer):
# Run the model.
with set_ascend_forward_context(attn_metadata,
self.vllm_config,
moe_comm_method=moe_comm_method,
num_tokens=input_batch_size):
last_hidden_states, hidden_states = self.model(

View File

@@ -112,6 +112,10 @@ class MtpProposer(Proposer):
(num_tokens, num_tokens_across_dp, with_prefill,
_) = self.runner._sync_metadata_across_dp(num_tokens,
with_prefill, False)
moe_comm_method = self.runner._select_moe_comm_method(
num_tokens, with_prefill)
is_running_torchair = self.torchair_graph_enabled and \
not with_prefill
@@ -142,6 +146,7 @@ class MtpProposer(Proposer):
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
moe_comm_method=moe_comm_method,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=0):
if is_running_torchair:
@@ -411,6 +416,9 @@ class MtpProposer(Proposer):
num_tokens_across_dp = self.runner.num_tokens_across_dp
with_prefill = self.runner.with_prefill
moe_comm_method = self.runner._select_moe_comm_method(
num_input_tokens, with_prefill)
for step in range(self.num_speculative_tokens):
with set_ascend_forward_context(
attn_metadata,
@@ -419,6 +427,7 @@ class MtpProposer(Proposer):
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
moe_comm_method=moe_comm_method,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=num_tokens):
with ProfileExecuteDuration().capture_async('mtp_forward'):

View File

@@ -1663,7 +1663,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
kv_connector_output=kv_connector_output,
)
def _select_moe_comm_method(self, num_tokens: int) -> str:
def _select_moe_comm_method(self, num_tokens: int,
with_prefill: bool) -> str:
"""1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
are designed for expert parallelism.
2. If expert parallel is enabled, we need to consider the soc version and the
@@ -1687,6 +1688,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
str: The selected MoE communication method, either "allgather", "mc2", or "alltoall".
"""
soc_version = get_ascend_soc_version()
quant_type = getattr(self.vllm_config.model_config.hf_config,
'moe_quantize', None)
if not self.parallel_config.enable_expert_parallel:
moe_comm_method = "allgather"
@@ -1694,12 +1697,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if num_tokens <= self.mc2_tokens_capacity and self.parallel_config.world_size_across_dp >= 16:
moe_comm_method = "mc2"
else:
moe_comm_method = "allgather"
if quant_type == "w4a8_dynamic":
moe_comm_method = "alltoall"
else:
moe_comm_method = "allgather"
elif soc_version in {AscendSocVersion.A3}:
moe_comm_method = "mc2" if num_tokens <= self.mc2_tokens_capacity else "alltoall"
else:
raise ValueError(f"Unsupported soc_version: {soc_version}")
if moe_comm_method == "allgather" and with_prefill:
moe_comm_method = "naivemulticast"
if is_global_first_rank():
logger.debug(f"num_tokens: {num_tokens}, "
f"moe_comm_method: {moe_comm_method}")
@@ -1728,7 +1738,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors) = (self._prepare_inputs(
scheduler_output, intermediate_tensors))
moe_comm_method = self._select_moe_comm_method(num_input_tokens)
moe_comm_method = self._select_moe_comm_method(num_input_tokens,
self.with_prefill)
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=False)
@@ -2100,7 +2111,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
(num_tokens, num_tokens_across_dp, with_prefill,
_) = self._sync_metadata_across_dp(num_tokens, with_prefill, False)
moe_comm_method = self._select_moe_comm_method(num_tokens)
moe_comm_method = self._select_moe_comm_method(num_tokens,
with_prefill)
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.seperate_routine(). This means that we are using