[Refactor] Adjustments to moe_comm_method selection process (#3001)

### What this PR does / why we need it?
Fix issues mentioned in
https://github.com/vllm-project/vllm-ascend/pull/2791 and some minor
refactoring.
1. Use Enum instead of string.
2. Avoid setting a new property to forward_context in
AscendFusedMoE.forward().
3. Enabling TokenDispatcherWithMoge.
4. Remove redundant code.

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?

Qwen3-30B-A3B/Qwen3-30B-A3B-W8A8/DeepSeek-V3-W4A8-Pruing/deepseek-mtp/pangu-pro-moe-pruing:
1. Enable/Disable EP
2. Aclgraph & eager


- vLLM version: v0.10.2
- vLLM main:
9607d5eb44

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
weichen
2025-09-22 19:12:58 +08:00
committed by GitHub
parent bb1f0d5a62
commit 37a0715eda
14 changed files with 170 additions and 351 deletions

View File

@@ -17,56 +17,7 @@ from unittest.mock import patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE, fused_experts_moge
class TestFusedExpertsMoGE(TestBase):
def test_fused_experts_moge(self):
with patch('torch_npu.npu_grouped_matmul') as mock_grouped_matmul, \
patch('torch_npu.npu_swiglu') as mock_swiglu, \
patch('vllm_ascend.utils.is_310p') as mock_is_310p:
mock_is_310p.return_value = False
mock_grouped_matmul.side_effect = lambda x, weight, **kwargs: [
torch.randn(x[0].shape[0], weight[0].shape[1])
]
mock_swiglu.side_effect = lambda x: x
hidden_states = torch.randn(4, 128)
w1 = torch.randn(4, 256, 128)
w2 = torch.randn(4, 128, 128)
topk_weights = torch.rand(4, 1)
topk_ids = torch.tensor([[0], [1], [2], [3]], dtype=torch.long)
top_k = 1
global_num_experts = 4
moe_parallel_config = type(
'MockConfig', (), {
'ep_size': 1,
'tp_size': 1,
'dp_size': 1,
'tp_rank': 0,
'dp_rank': 0,
'ep_rank': 0,
'use_ep': True
})()
output = fused_experts_moge(
hidden_states=hidden_states,
w1=w1,
w2=w2,
moe_parallel_config=moe_parallel_config,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=global_num_experts,
apply_router_weight_on_input=True,
)
self.assertEqual(output.shape, (4, 128))
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
class TestLoadWeight(TestBase):

View File

@@ -23,6 +23,7 @@ from pytest_mock import MockerFixture
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
AscendUnquantizedFusedMoEMethod)
from vllm_ascend.ops.moe.experts_selector import select_experts
@@ -55,6 +56,26 @@ def mock_npu_format_cast(weight_data, format):
return weight_data
@pytest.fixture(autouse=True)
def setup_vllm_config_mock(mocker: MockerFixture):
mock_hf_config = MagicMock()
mock_hf_config.model_type = "llama"
mock_model_config = MagicMock()
mock_model_config.hf_config = mock_hf_config
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.parallel_config = MagicMock(tensor_parallel_size=2)
mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
mock_vllm_config.model_config.max_model_len = 2048
mocker.patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
return_value=mock_vllm_config)
mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
return_value=mock_vllm_config)
@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
mock_moe_comm_method = MagicMock()
@@ -74,7 +95,7 @@ def mock_dist_env(mocker: MockerFixture):
mock_forward_context_obj = MagicMock(
moe_comm_method=mock_moe_comm_method,
moe_comm_method_name="mc2commimpl",
moe_comm_type=MoECommType.MC2,
max_tokens_across_dp=10,
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
mc2_mask=torch.zeros(16, dtype=torch.bool),
@@ -104,12 +125,6 @@ def mock_dist_env(mocker: MockerFixture):
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
return_value=MagicMock(
parallel_config=MagicMock(tensor_parallel_size=2),
scheduler_config=MagicMock(max_num_seqs=4),
model_config=MagicMock(max_model_len=2048)
)), \
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
return_value=mock_forward_context_obj), \
@@ -501,7 +516,7 @@ class TestUnifiedApplyMLP(TestBase):
mock_get_forward_context):
mock_forward_context = MagicMock()
mock_forward_context.moe_comm_method_name = "mc2commimpl"
mock_forward_context.moe_comm_type = MoECommType.MC2
mock_get_forward_context.return_value = mock_forward_context
mock_is_310p.return_value = False

View File

@@ -24,6 +24,7 @@ class TestMoECommMethod(TestBase):
self.moe_config.dp_group = MagicMock()
self.moe_config.num_global_redundant_experts = 0
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
@@ -31,7 +32,11 @@ class TestMoECommMethod(TestBase):
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
def test_all_gather_comm_impl(self, mock_token_dispatcher,
mock_prepare_finalize,
mock_get_forward_context):
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()
# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "all_gather"
@@ -64,13 +69,18 @@ class TestMoECommMethod(TestBase):
comm_impl.finalize(h_out, reduce_results=True)
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithMC2"
)
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithMC2")
def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
mock_get_forward_context):
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()
# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "mc2"
@@ -104,6 +114,7 @@ class TestMoECommMethod(TestBase):
comm_impl.finalize(h_out, reduce_results=True)
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAll2All"
@@ -111,7 +122,11 @@ class TestMoECommMethod(TestBase):
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAll2AllV")
def test_alltoall_comm_impl(self, mock_token_dispatcher,
mock_prepare_finalize,
mock_get_forward_context):
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()
# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "alltoall"
@@ -140,6 +155,7 @@ class TestMoECommMethod(TestBase):
mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, False, None)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch(
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
@@ -148,7 +164,11 @@ class TestMoECommMethod(TestBase):
@patch("vllm_ascend.ops.moe.moe_comm_method.unified_apply_mlp")
def test_fused_experts_method(self, mock_unified_apply_mlp,
mock_token_dispatcher, mock_prepare_finalize,
mock_get_forward_context):
mock_get_forward_context,
mock_get_current_vllm_config):
# Mock vLLM config
mock_get_current_vllm_config.return_value = MagicMock()
# Mock forward context
mock_context = MagicMock()
mock_context.moe_comm_method = "all_gather"

View File

@@ -48,18 +48,27 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
output_size = 56
group_size = 2
@patch('vllm_ascend.quantization.w4a8_dynamic.get_ascend_config')
@patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
@patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
@patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
@patch('torch.distributed.get_rank', return_value=0)
def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ep_group,
get_current_vllm_config):
get_current_vllm_config, mock_get_ascend_config):
# Mock ascend config
mock_ascend_config = Mock()
mock_ascend_config.dynamic_eplb = False
mock_get_ascend_config.return_value = mock_ascend_config
mock_vllm_config = Mock()
mock_vllm_config.quant_config = Mock(quant_description={
"group_size": self.group_size,
"version": "0.0.0"
})
mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True)
mock_vllm_config.scheduler_config = Mock(max_num_batched_tokens=2048,
max_model_len=2048,
enable_chunked_prefill=False)
get_current_vllm_config.return_value = mock_vllm_config
self.quant_method = AscendW4A8DynamicFusedMoEMethod()

View File

@@ -15,6 +15,7 @@ from unittest.mock import MagicMock, patch
import pytest
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.utils import AscendSocVersion
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -24,21 +25,21 @@ from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
"soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method",
[
# Case 1: Expert parallel is disabled, should always be 'allgather'
(AscendSocVersion.A2, False, 8, 100, 256, None, "allgather"),
(AscendSocVersion.A3, False, 16, 500, 256, None, "allgather"),
(AscendSocVersion.A2, False, 8, 100, 256, None, MoECommType.ALLGATHER),
(AscendSocVersion.A3, False, 16, 500, 256, None, MoECommType.ALLGATHER),
# Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2
(AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", "alltoall"),
(AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", "alltoall"),
(AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", "mc2"), # meets mc2 condition
(AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", MoECommType.ALLTOALL),
(AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", MoECommType.ALLTOALL),
(AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", MoECommType.MC2), # meets mc2 condition
# Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather
(AscendSocVersion.A2, True, 8, 100, 256, None, "allgather"),
(AscendSocVersion.A2, True, 16, 257, 256, None, "allgather"),
(AscendSocVersion.A2, True, 8, 100, 256, None, MoECommType.ALLGATHER),
(AscendSocVersion.A2, True, 16, 257, 256, None, MoECommType.ALLGATHER),
# Case 4: A3 SOC
(AscendSocVersion.A3, True, 8, 100, 256, None, "mc2"),
(AscendSocVersion.A3, True, 8, 257, 256, None, "alltoall"),
(AscendSocVersion.A3, True, 8, 100, 256, None, MoECommType.MC2),
(AscendSocVersion.A3, True, 8, 257, 256, None, MoECommType.ALLTOALL),
])
# yapf: enable
def test_select_moe_comm_method(soc_version, enable_expert_parallel,

View File

@@ -22,6 +22,13 @@ class FusedMoEState(Enum):
All2AllSeq = 5
class MoECommType(Enum):
ALLGATHER = 0
MC2 = 1
ALLTOALL = 2
NAIVE_MULTICAST = 3
# TODO(zzzzwwjj): add soc_version to choose branch
def _get_fused_moe_state(ep_size: int, with_prefill: bool,
is_deepseek_v3_r1: bool):
@@ -52,7 +59,7 @@ def set_ascend_forward_context(
with_prefill: bool = True,
in_profile_run: bool = False,
reserved_mc2_mask: Optional[torch.Tensor] = None,
moe_comm_method: str = "",
moe_comm_type: Optional[MoECommType] = None,
num_actual_tokens: Optional[int] = None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None,
@@ -72,7 +79,11 @@ def set_ascend_forward_context(
batch_descriptor=batch_descriptor,
):
forward_context = get_forward_context()
forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
from vllm_ascend.ops.moe.moe_comm_method import get_moe_comm_method
forward_context.moe_comm_type = moe_comm_type
forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
forward_context.with_prefill = with_prefill
tp_world_size = get_tensor_model_parallel_world_size()
ep_size = (get_ep_group().world_size if

View File

@@ -23,106 +23,23 @@ from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
tensor_model_parallel_all_reduce)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import \
FusedMoEParallelConfig # isort: skip
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
determine_default_log2phy_map)
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
AlltoAllCommImpl, MC2CommImpl,
NaiveMulticastCommImpl)
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
def fused_experts_moge(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
moe_parallel_config: FusedMoEParallelConfig,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
global_num_experts: int,
expert_map: torch.Tensor = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
"""
Args:
hidden_states: Hidden states of shape (num_tokens, hidden_size).
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
topk_weights: Routing weights of shape (num_tokens, top_k).
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
top_k: Number of experts to select.
expert_map: Expert mapping of shape (num_experts,).
Returns:
hidden_states: Hidden states after routing.
"""
ep_size = moe_parallel_config.ep_size
local_num_experts = global_num_experts // ep_size
local_num_group = top_k // ep_size
bsz, _ = hidden_states.shape
flatten_topk_ids = topk_ids.view(-1)
sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
sorted_topk_ids = sorted_topk_ids.to(torch.int32)
sorted_hidden_states = hidden_states.index_select(
0, sorted_topk_ids // local_num_group)
experts_id = torch.arange(0,
local_num_experts,
dtype=topk_ids.dtype,
device=topk_ids.device)
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
torch.float32).sum(0)
topk_scales = topk_weights.view(-1).index_select(
0, sorted_topk_ids).unsqueeze(-1)
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
gate_up_out = torch_npu.npu_grouped_matmul(
x=[sorted_hidden_states],
weight=[w1],
split_item=2,
group_list_type=0,
group_type=0,
group_list=group_list,
)[0]
if is_310p():
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
torch.float16)
else:
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
gate_up_out *= topk_scales
down_out_list = torch_npu.npu_grouped_matmul(
x=[gate_up_out],
weight=[w2],
split_item=2,
group_list_type=0,
group_type=0,
group_list=group_list,
)[0]
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
final_hidden_states = unsorted_hidden_states.reshape(
bsz, top_k // ep_size, -1).sum(1)
return final_hidden_states
def unquantized_fused_moe_init_func(self, *args, **kwargs):
original_unquantized_fused_moe_init_func(self, *args, **kwargs)
@@ -178,20 +95,6 @@ def forward_oot(
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
if topk_ids.shape[1] < top_k or is_310p():
assert global_num_experts is not None
return fused_experts_moge(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
moe_parallel_config=self.moe.moe_parallel_config,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
moe_comm_method = get_forward_context().moe_comm_method
return moe_comm_method.fused_experts(hidden_states=x,
w1=layer.w13_weight,
@@ -277,13 +180,7 @@ class AscendFusedMoE(FusedMoE):
if self.dynamic_eplb:
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
for method in {
AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl,
NaiveMulticastCommImpl
}:
setattr(
self, method.__name__.lower(),
method(moe_config=self.moe_config)) # type: ignore[abstract]
setup_moe_comm_method(self.moe_config)
def update_expert_map(self, new_expert_map):
self.expert_map = new_expert_map
@@ -307,8 +204,8 @@ class AscendFusedMoE(FusedMoE):
outputs since each rank only has partial outputs.
"""
forward_context = get_forward_context()
moe_comm_method_name = forward_context.moe_comm_method_name
if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
@@ -318,10 +215,6 @@ class AscendFusedMoE(FusedMoE):
assert self.quant_method is not None
forward_context = get_forward_context()
moe_comm_method_name = forward_context.moe_comm_method_name
forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
hidden_states=hidden_states, router_logits=router_logits)
@@ -449,8 +342,8 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
forward_context = get_forward_context()
moe_comm_method_name = forward_context.moe_comm_method_name
if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
shared_out = tensor_model_parallel_all_reduce(shared_out)
_, fused_out = AscendFusedMoE.forward(

View File

@@ -41,9 +41,7 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
determine_default_log2phy_map)
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
AlltoAllCommImpl, MC2CommImpl,
NaiveMulticastCommImpl)
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
get_all_reduce_merge_state,
@@ -339,13 +337,7 @@ class AscendFusedMoE(FusedMoE):
self.moe_config.mc2_group = get_mc2_group()
self.moe_config.num_global_redundant_experts = self.global_redundant_expert_num
for method in {
AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl,
NaiveMulticastCommImpl
}:
setattr(
self, method.__name__.lower(),
method(moe_config=self.moe_config)) # type: ignore[abstract]
setup_moe_comm_method(self.moe_config)
def update_expert_map(self, new_expert_map):
self.expert_map = new_expert_map
@@ -360,22 +352,6 @@ class AscendFusedMoE(FusedMoE):
if self.moe_load is not None:
self.moe_load.zero_()
def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor):
assert (len(x.shape) == 2)
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
device=x.device,
dtype=x.dtype)
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
self.dp_rank - 1]
end = cu_tokens_across_dp_cpu[self.dp_rank]
buffer[start:end, :].copy_(x)
for idx in range(self.dp_size):
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
end = cu_tokens_across_dp_cpu[idx]
get_dp_group().broadcast(buffer[start:end, :], idx)
return buffer
def forward(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
@@ -412,9 +388,6 @@ class AscendFusedMoE(FusedMoE):
mc2_mask = chunk_mc2_mask[tp_rank]
replace_allreduce = True
moe_comm_method_name = forward_context.moe_comm_method_name
forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
hidden_states=hidden_states,
router_logits=router_logits,

View File

@@ -13,14 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Optional
from typing import Any, Dict, Optional
import torch
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
FusedMoEPrepareAndFinalizeWithAll2All,
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
@@ -28,13 +31,31 @@ from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.moe.token_dispatcher import (TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather,
TokenDispatcherWithMC2)
TokenDispatcherWithMC2,
TokenDispatcherWithMoge)
_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
def get_moe_comm_method(
moe_comm_type: Optional[MoECommType]) -> Optional[MoECommMethod]:
return _MoECommMethods.get(moe_comm_type)
def setup_moe_comm_method(moe_config):
_MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
_MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
_MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
_MoECommMethods[MoECommType.NAIVE_MULTICAST] = NaiveMulticastCommImpl(
moe_config)
class MoECommMethod(ABC):
"""Base class for MoE communication methods."""
def __init__(self, moe_config: FusedMoEConfig):
self.model_type = get_current_vllm_config(
).model_config.hf_config.model_type
self.moe_config = moe_config
self.mc2_mask = None
@@ -113,8 +134,8 @@ class MoECommMethod(ABC):
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=use_int8_w8a8 or use_int4_w4a8)
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type = \
results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"]
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales = \
results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales")
mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
w1=w1,
@@ -126,6 +147,7 @@ class MoECommMethod(ABC):
group_list_type=group_list_type,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
topk_scales=topk_scales,
with_quant=use_int8_w8a8
or use_int4_w4a8,
fusion=use_int8_w8a8,
@@ -170,94 +192,21 @@ class AllGatherCommImpl(MoECommMethod):
"""
def _get_token_dispatcher(self):
return TokenDispatcherWithAllGather(
top_k=self.moe_config.experts_per_token,
num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts)
if self.model_type == "PanguProMoE":
return TokenDispatcherWithMoge(
top_k=self.moe_config.experts_per_token,
num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts)
else:
return TokenDispatcherWithAllGather(
top_k=self.moe_config.experts_per_token,
num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts)
def _get_fused_moe_prepare_finalize(self):
return FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config)
class NativeAllGatherCommImpl(AllGatherCommImpl):
"""This implementation should be compatible with all scenarios.
Note that this implementation purely consists of native PyTorch ops
and does not use any NPU-specific ops. So the performance may not be optimal.
But it is a good fallback for scenarios where NPU-specific ops are not available.
"""
def permute(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor,
num_experts: int,
apply_a8_quantization: bool,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
num_tokens = hidden_states.shape[0]
# Generate token indices and flatten
token_indices = torch.arange(num_tokens,
device=hidden_states.device,
dtype=torch.int64)
token_indices = (token_indices.unsqueeze(1).expand(
-1, self.moe_config.experts_per_token).reshape(-1))
# Flatten token-to-expert mappings and map to local experts
weights_flat = topk_weights.view(-1)
experts_flat = topk_ids.view(-1)
local_experts_flat = (expert_map[experts_flat]
if expert_map is not None else experts_flat)
# Filter valid token-expert pairs
mask = local_experts_flat != -1
# FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
# So we need to filter out invalid tokens by zeroing their weights.
# This is a workaround and should be removed after the issue is fixed
filtered_weights = torch.where(mask, weights_flat,
torch.zeros_like(weights_flat)).to(
topk_weights.dtype)
filtered_experts = torch.where(
mask,
local_experts_flat,
torch.full_like(local_experts_flat, num_experts),
).to(topk_ids.dtype)
# Sort by local expert IDs
sort_indices = torch.argsort(filtered_experts.view(torch.float32))
self.sorted_token_indices = token_indices[sort_indices]
self.sorted_weights = filtered_weights[sort_indices]
# Compute token counts with minlength of num_experts
# This is equivalent to but faster than:
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
token_counts = torch.zeros(num_experts + 1,
device=hidden_states.device,
dtype=torch.int64)
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
expert_tokens = token_counts[:num_experts]
# Rearrange hidden_states
permuted_hidden_states = hidden_states[self.sorted_token_indices]
group_list_type = 1 # `count` mode
return permuted_hidden_states, expert_tokens, None, group_list_type
def unpermute(self, mlp_output: torch.Tensor,
hidden_states: torch.Tensor) -> None:
mlp_output = mlp_output * self.sorted_weights.unsqueeze(1)
final_hidden_states = torch.zeros_like(hidden_states)
final_hidden_states.index_add_(0, self.sorted_token_indices,
mlp_output)
hidden_states[:] = final_hidden_states
class MC2CommImpl(MoECommMethod):
"""This implementation is for the scenarios listed below:
1. `enable_expert_parallel=True`.

View File

@@ -21,6 +21,7 @@ import torch_npu
from torch.nn.functional import pad
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.utils import dispose_tensor, is_310p
@@ -76,7 +77,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
bias1, bias2 = None, None
_output_dtype = w2_scale.dtype
is_mc2 = get_forward_context().moe_comm_method_name == "mc2commimpl"
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
if w1_scale_bias is None and is_mc2:
if w1_scale.dtype != torch.float32:
w1_scale = w1_scale.to(torch.float32)

View File

@@ -377,14 +377,13 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
# mypy: disable-error-code="override"
class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
class TokenDispatcherWithMoge(MoETokenDispatcher):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.apply_router_weight_on_input = False
self.local_ep = 1
self.local_num_experts = self.num_experts // self.local_ep
self.local_num_group = self.top_k // self.local_ep
self.local_num_experts = self.num_experts // self.ep_size
self.local_num_group = self.top_k // self.ep_size
self.bsz = None
def token_dispatch(self,
@@ -401,17 +400,6 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
self.apply_router_weight_on_input = apply_router_weight_on_input
if self.apply_router_weight_on_input:
assert (topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape
assert (
topk == 1
), "Only support topk=1 when `apply_router_weight_on_input` is True"
hidden_states = hidden_states * \
topk_weights.to(hidden_states.dtype)
self.bsz, _ = hidden_states.shape
flatten_topk_ids = topk_ids.view(-1)
self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
@@ -445,7 +433,7 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
unsorted_hidden_states = hidden_states.index_select(
0, unsorted_topk_ids)
final_hidden_states = unsorted_hidden_states.reshape(
self.bsz, self.top_k // self.local_ep, -1).sum(1)
self.bsz, self.top_k // self.ep_size, -1).sum(1)
return final_hidden_states

View File

@@ -117,11 +117,11 @@ class EagleProposer(Proposer):
skip_attn: bool = False,
num_reqs: int = 0,
num_tokens_across_dp: Optional[torch.Tensor] = None):
moe_comm_method = self.runner._select_moe_comm_method(
moe_comm_type = self.runner._select_moe_comm_method(
num_tokens, with_prefill)
with set_ascend_forward_context(None,
self.vllm_config,
moe_comm_method=moe_comm_method,
moe_comm_type=moe_comm_type,
num_tokens=num_tokens):
self.model(
input_ids=self.input_ids[:num_tokens],
@@ -454,7 +454,7 @@ class EagleProposer(Proposer):
with_prefill = attn_metadata.attn_state not in [
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]
moe_comm_method = self.runner._select_moe_comm_method(
moe_comm_type = self.runner._select_moe_comm_method(
num_input_tokens, with_prefill)
# copy inputs to buffer for cudagraph
@@ -463,7 +463,7 @@ class EagleProposer(Proposer):
attn_metadata.block_tables = block_table.to(device)
with set_ascend_forward_context(attn_metadata,
self.vllm_config,
moe_comm_method=moe_comm_method,
moe_comm_type=moe_comm_type,
num_tokens=num_input_tokens):
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
@@ -495,7 +495,7 @@ class EagleProposer(Proposer):
else:
input_batch_size = batch_size
moe_comm_method = self.runner._select_moe_comm_method(
moe_comm_type = self.runner._select_moe_comm_method(
input_batch_size, False)
attn_metadata.num_actual_tokens = batch_size
@@ -568,7 +568,7 @@ class EagleProposer(Proposer):
# Run the model.
with set_ascend_forward_context(attn_metadata,
self.vllm_config,
moe_comm_method=moe_comm_method,
moe_comm_type=moe_comm_type,
num_tokens=input_batch_size):
last_hidden_states, hidden_states = self.model(

View File

@@ -113,7 +113,7 @@ class MtpProposer(Proposer):
_) = self.runner._sync_metadata_across_dp(num_tokens,
with_prefill, False)
moe_comm_method = self.runner._select_moe_comm_method(
moe_comm_type = self.runner._select_moe_comm_method(
num_tokens, with_prefill)
is_running_torchair = self.torchair_graph_enabled and \
@@ -146,7 +146,7 @@ class MtpProposer(Proposer):
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
moe_comm_method=moe_comm_method,
moe_comm_type=moe_comm_type,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=0):
if is_running_torchair:
@@ -425,7 +425,7 @@ class MtpProposer(Proposer):
num_tokens_across_dp = self.runner.num_tokens_across_dp
with_prefill = self.runner.with_prefill
moe_comm_method = self.runner._select_moe_comm_method(
moe_comm_type = self.runner._select_moe_comm_method(
num_input_tokens, with_prefill)
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=False)
@@ -440,7 +440,7 @@ class MtpProposer(Proposer):
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
moe_comm_method=moe_comm_method,
moe_comm_type=moe_comm_type,
aclgraph_runtime_mode=aclgraph_runtime_mode,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=num_tokens):

View File

@@ -94,7 +94,8 @@ from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
scatter_mm_placeholders)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.ascend_forward_context import (MoECommType,
set_ascend_forward_context)
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
@@ -1860,7 +1861,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)
def _select_moe_comm_method(self, num_tokens: int,
with_prefill: bool) -> str:
with_prefill: bool) -> MoECommType:
"""1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
are designed for expert parallelism.
2. If expert parallel is enabled, we need to consider the soc version and the
@@ -1881,36 +1882,44 @@ class NPUModelRunner(LoRAModelRunnerMixin):
ValueError: If the soc version is unsupported.
Returns:
str: The selected MoE communication method, either "allgather", "mc2", or "alltoall".
MoECommType: The selected MoE communication method.
"""
soc_version = get_ascend_soc_version()
quant_type = getattr(self.vllm_config.model_config.hf_config,
'moe_quantize', None)
model_type = self.vllm_config.model_config.hf_config.model_type
if not self.parallel_config.enable_expert_parallel:
moe_comm_method = "allgather"
moe_comm_type = MoECommType.ALLGATHER
elif soc_version in {AscendSocVersion.A2}:
if num_tokens <= self.mc2_tokens_capacity and self.parallel_config.world_size_across_dp >= 16:
moe_comm_method = "mc2"
if (num_tokens <= self.mc2_tokens_capacity
and self.parallel_config.world_size_across_dp >= 16):
moe_comm_type = MoECommType.MC2
else:
# Currently, w4a8_dynamic does not support allgatherep
if quant_type == "w4a8_dynamic":
moe_comm_method = "alltoall"
moe_comm_type = MoECommType.ALLTOALL
else:
moe_comm_method = "allgather"
moe_comm_type = MoECommType.ALLGATHER
elif soc_version in {AscendSocVersion.A3}:
moe_comm_method = "mc2" if num_tokens <= self.mc2_tokens_capacity else "alltoall"
moe_comm_type = (MoECommType.MC2
if num_tokens <= self.mc2_tokens_capacity else
MoECommType.ALLTOALL)
else:
raise ValueError(f"Unsupported soc_version: {soc_version}")
if moe_comm_method == "allgather" and with_prefill:
moe_comm_method = "naivemulticast"
if moe_comm_type == MoECommType.ALLGATHER and with_prefill:
moe_comm_type = MoECommType.NAIVE_MULTICAST
# PanguProMoE only supports allgather
if model_type == "PanguProMoE":
moe_comm_type = MoECommType.ALLGATHER
if is_global_first_rank():
logger.debug(f"num_tokens: {num_tokens}, "
f"moe_comm_method: {moe_comm_method}")
return moe_comm_method
f"moe_comm_type: {moe_comm_type}")
return moe_comm_type
@torch.inference_mode()
def execute_model(
@@ -1942,8 +1951,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.dynamic_eplb:
self.eplb_updator.take_update_info_from_eplb_process()
moe_comm_method = self._select_moe_comm_method(num_input_tokens,
self.with_prefill)
moe_comm_type = self._select_moe_comm_method(num_input_tokens,
self.with_prefill)
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
scheduler_output.total_num_scheduled_tokens
@@ -1962,7 +1971,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=self.with_prefill,
reserved_mc2_mask=self.reserved_mc2_mask,
moe_comm_method=moe_comm_method,
moe_comm_type=moe_comm_type,
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor,
num_actual_tokens=scheduler_output.
@@ -2351,8 +2360,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
(num_tokens, num_tokens_across_dp, with_prefill,
_) = self._sync_metadata_across_dp(num_tokens, with_prefill, False)
moe_comm_method = self._select_moe_comm_method(num_tokens,
with_prefill)
moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill)
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.seperate_routine(). This means that we are using
@@ -2472,7 +2480,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
with_prefill=with_prefill,
in_profile_run=self.in_profile_run,
reserved_mc2_mask=self.reserved_mc2_mask,
moe_comm_method=moe_comm_method,
moe_comm_type=moe_comm_type,
num_actual_tokens=0,
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor,