[MoE] [Refactor] Combine common_fused_moe and fused_moe (#3176)

### What this PR does / why we need it?
1. Move additional functionalities from fused_moe.py to
common_fused_moe.py and remove fused_moe.py
2. Remove unnecessary custom classes from qwen3_moe.py, and it will be
completely removed after we release vllm-ascend v0.11.0

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

Qwen3-30B-A3B/Qwen3-30B-A3B-W8A8/DeepSeek-V3-W4A8-Pruing/deepseek-mtp/pangu-pro-moe-pruing:

1. Enable/Disable EP
3. Aclgraph & eager
4. SP


- vLLM version: v0.11.0

---------

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
weichen
2025-10-09 14:12:46 +08:00
committed by GitHub
parent a36e3da78e
commit 94dd832815
17 changed files with 175 additions and 1110 deletions

View File

@@ -96,7 +96,7 @@ def mock_distributed():
patch("vllm_ascend.models.deepseek_v2.get_pp_group", return_value=pp_group), \
patch("vllm_ascend.models.deepseek_v2.get_pp_group",
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
patch("vllm_ascend.ops.common_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
patch("vllm_ascend.ops.moe.token_dispatcher.torch.distributed.get_rank", return_value=0), \
patch("vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version", return_value=None), \
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,

View File

@@ -1,68 +0,0 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import math
import unittest
import torch
from vllm_ascend.torchair.models.qwen3_moe import CustomQwen3MoeAttention
class DummyRMSNorm:
def __init__(self, dim: int, eps: float = 1e-6):
self.dim = dim
self.eps = eps
def __call__(self, x):
mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
denom = (mean_sq + self.eps).sqrt()
return x / denom
class TestCustomQwen3MoeAttention(unittest.TestCase):
def setUp(self):
self.batch = 2
self.seq_len = 3
self.q_size = 8
self.kv_size = 8
self.head_dim = 4
self.rms_eps = 1e-6
total_dim = self.q_size + 2 * self.kv_size
self.qkv = torch.arange(self.batch * self.seq_len * total_dim,
dtype=torch.float32).reshape(
self.batch, self.seq_len, total_dim)
def test_constant_input_normalization(self):
ones_qkv = torch.ones((1, 1, self.q_size + 2 * self.kv_size),
dtype=torch.float32)
q_norm = DummyRMSNorm(self.head_dim, self.rms_eps)
k_norm = DummyRMSNorm(self.head_dim, self.rms_eps)
q, k, v = CustomQwen3MoeAttention.normalize_qkv(
ones_qkv, self.q_size, self.kv_size, self.head_dim, q_norm, k_norm)
norm_val = 1.0 / math.sqrt(1.0 + self.rms_eps)
expected_q = torch.full((1, 1, self.q_size), norm_val)
expected_k = torch.full((1, 1, self.kv_size), norm_val)
expected_v = torch.ones((1, 1, self.kv_size), dtype=torch.float32)
self.assertTrue(torch.allclose(q, expected_q, atol=1e-6))
self.assertTrue(torch.allclose(k, expected_k, atol=1e-6))
self.assertTrue(torch.equal(v, expected_v))

View File

@@ -21,6 +21,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
self.moe_config.tp_size = 1
self.moe_config.ep_size = 1
self.moe_config.dp_group = MagicMock()
self.moe_config.original_num_experts = 8
@patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
@@ -196,7 +197,6 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
h_out, r_out, _ = layer.prepare(hidden_states,
router_logits,
rm_router_logits=False,
gate=mock_gate)
# After all-gather with DP=2, should double the batch size
@@ -265,7 +265,6 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
# Run prepare
h_out, r_out, _ = layer.prepare(hidden_states,
router_logits,
rm_router_logits=False,
gate=mock_gate)
# Should be global tensor: [7, 8] and [7, 2]

View File

@@ -24,8 +24,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
AscendUnquantizedFusedMoEMethod)
from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp
from vllm_ascend.utils import AscendSocVersion, adapt_patch
@@ -70,7 +69,7 @@ def setup_vllm_config_mock(mocker: MockerFixture):
mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
mock_vllm_config.model_config.max_model_len = 2048
mocker.patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
mocker.patch('vllm_ascend.ops.common_fused_moe.get_current_vllm_config',
return_value=mock_vllm_config)
mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
return_value=mock_vllm_config)
@@ -104,24 +103,24 @@ def mock_dist_env(mocker: MockerFixture):
with patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \
patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.common_fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.common_fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.common_fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.common_fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_ascend_config',
patch('vllm_ascend.ops.common_fused_moe.get_ascend_config',
return_value=MagicMock(
torchair_graph_config=MagicMock(enabled=False),
enable_multistream_moe=False,
expert_map_path=None
)), \
patch('vllm_ascend.ops.fused_moe.determine_expert_map',
patch('vllm_ascend.ops.common_fused_moe.determine_expert_map',
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
patch('vllm_ascend.ops.fused_moe.get_forward_context',
patch('vllm_ascend.ops.common_fused_moe.get_forward_context',
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
return_value=mock_forward_context_obj), \
@@ -252,196 +251,6 @@ class MockFusedMoEMethod(FusedMoEMethodBase):
pass
class TestAscendFusedMoe:
def test_init_no_quant(self, mock_dist_env, default_moe_config):
layer = AscendFusedMoE(**default_moe_config)
layer.w13_weight = nn.Parameter(
torch.randn(default_moe_config['num_experts'],
default_moe_config['intermediate_size'] * 2,
default_moe_config['hidden_size']))
layer.w2_weight = nn.Parameter(
torch.randn(default_moe_config['num_experts'],
default_moe_config['hidden_size'],
default_moe_config['intermediate_size']))
assert layer.num_experts == default_moe_config['num_experts']
assert layer.top_k == default_moe_config['top_k']
assert hasattr(layer, 'w13_weight')
assert hasattr(layer, 'w2_weight')
with pytest.raises(AssertionError):
error_config = default_moe_config.copy()
error_config['use_grouped_topk'] = True
layer = AscendFusedMoE(**error_config)
with pytest.raises(ValueError):
error_config = default_moe_config.copy()
error_config['scoring_func'] = "random"
layer = AscendFusedMoE(**error_config)
def test_init_with_quant(self, mock_dist_env, default_moe_config):
mock_quant_config = MagicMock()
mock_quant_method = MockFusedMoEMethod()
mock_quant_config.get_quant_method.return_value = mock_quant_method
moe = AscendFusedMoE(**default_moe_config,
quant_config=mock_quant_config)
assert moe.quant_method is not None
assert moe.quant_method == mock_quant_method
@pytest.mark.parametrize(
"others_param",
[[None,
MagicMock(return_value=torch.randn(5, 32)), False, 5, None],
[2, None, False, 5, None], [None, None, True, 5, None],
[None, None, False, 1, None], [None, None, True, 5, 1],
[None, None, False, 5, 1]])
def test_forward(self, mock_dist_env, default_moe_config, others_param):
top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param
inputs = torch.randn(num_tokens, 32)
router_logits = torch.randn(num_tokens, 8)
moe = AscendFusedMoE(**default_moe_config)
if ep_size == 1:
moe.moe_parallel_config.ep_size = 1
moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
forward_context = mock_dist_env['mock_forward_context_obj']
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
return_value=forward_context):
output = moe.forward(inputs,
router_logits,
is_prefill=is_prefill,
top_k=top_k,
shared_experts=shared_experts)
moe.quant_method.apply.assert_called_once()
if shared_experts:
assert output[0].shape == (num_tokens, 32)
assert output[1].shape == (num_tokens, 10)
else:
assert output.shape == (num_tokens, 32)
def test_forward_ms_fused_moe_comp(self, mock_dist_env,
default_moe_config):
inputs = torch.randn(5, 32)
router_logits = torch.randn(5, 8)
moe = AscendFusedMoE(**default_moe_config)
moe.quant_method = MockQuantMethod(None, 5)
output = moe._forward_ms_fused_moe_comp(inputs,
router_logits,
is_prefill=False,
real_top_k=1)
moe.quant_method.apply.assert_called_once()
assert output.shape == (5, 32)
class TestAscendUnquantizedFusedMoEMethod:
def test_process_weights_after_loading(self, moe_method, mock_dist_env):
layer = MagicMock()
layer.w13_weight.data = torch.randn(16, 32)
layer.w2_weight.data = torch.randn(16, 32)
with patch('torch_npu.npu_format_cast', mock_npu_format_cast), \
patch('vllm_ascend.utils.is_310p', return_value=False):
moe_method.process_weights_after_loading(layer)
assert isinstance(layer.w13_weight, torch.nn.Parameter)
assert isinstance(layer.w2_weight, torch.nn.Parameter)
assert not layer.w13_weight.requires_grad
assert not layer.w2_weight.requires_grad
@pytest.mark.parametrize("others_param",
[[256, 4], [128, 1], [128, 1], [128, 4]])
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
global_num_experts, ep_size = others_param
is_prefill = False
forward_context = mock_dist_env['mock_forward_context_obj']
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
return_value=forward_context):
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
router_logits = torch.randn(8, 8)
layer = MagicMock()
local_num_experts = 2
hidden_size = 2
intermediate_size_per_partition = 4
layer.w13_weight = torch.randn(local_num_experts,
intermediate_size_per_partition * 2,
hidden_size)
layer.w2_weight = torch.randn(local_num_experts, hidden_size,
intermediate_size_per_partition)
result = moe_method.apply(layer=layer,
x=x,
router_logits=router_logits,
top_k=2,
renormalize=True,
global_num_experts=global_num_experts,
is_prefill=is_prefill)
mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
mock_moe_comm_method.fused_experts.assert_called_once()
expected_shape = (16, 2)
assert result.shape == expected_shape
@pytest.mark.parametrize("others_param", [16, 1, 4])
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
ep_size = others_param
is_prefill = False
forward_context = mock_dist_env['mock_forward_context_obj']
with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
if ep_size == 1:
x = x.view(-1, 2)
router_logits = torch.randn(8, 8)
layer = MagicMock()
local_num_experts = 2
hidden_size = 2
intermediate_size_per_partition = 4
layer.w13_weight = torch.randn(local_num_experts,
intermediate_size_per_partition * 2,
hidden_size)
layer.w2_weight = torch.randn(local_num_experts, hidden_size,
intermediate_size_per_partition)
result = moe_method.apply(layer=layer,
x=x,
router_logits=router_logits,
top_k=2,
renormalize=True,
global_num_experts=128,
expert_map=expert_map,
is_prefill=is_prefill)
mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
mock_moe_comm_method.fused_experts.assert_called_once()
expected_shape = (16, 2)
assert result.shape == expected_shape
class TestExpertsSelector:
@pytest.mark.parametrize("global_num_experts", [[256], [128]])

View File

@@ -63,7 +63,7 @@ class TestMoECommMethod(TestBase):
# Verify prepare was called with correct arguments
mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, False, None)
hidden_states, router_logits, False, False, None)
# Test finalize method
comm_impl.finalize(h_out, reduce_results=True)
@@ -108,7 +108,7 @@ class TestMoECommMethod(TestBase):
# Verify prepare was called with correct arguments
mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, False, None)
hidden_states, router_logits, False, False, None)
# Test finalize method
comm_impl.finalize(h_out, reduce_results=True)
@@ -153,7 +153,7 @@ class TestMoECommMethod(TestBase):
# Verify prepare was called with correct arguments
mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, False, None)
hidden_states, router_logits, False, False, None)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")

View File

@@ -45,10 +45,6 @@ def register_model():
"DeepSeekMTPModel",
"vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
ModelRegistry.register_model(
"Qwen3MoeForCausalLM",
"vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
# There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
# to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM.
ModelRegistry.register_model(

View File

@@ -62,7 +62,7 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.mla import AscendMLAModules
from vllm_ascend.models.layers.sfa import (AscendSFAModules,
AscendSparseFlashAttention, Indexer)
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
class CustomDeepseekV2RowParallelLinear(RowParallelLinear):

View File

@@ -1,263 +0,0 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from vllm/model_executor/models/qwen3_moe.py
# This file is a part of the vllm-ascend project.
from typing import Optional
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, CompilationLevel, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
get_tp_group)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.interfaces import (MixtureOfExperts,
SupportsLoRA, SupportsPP)
from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
Qwen3MoeDecoderLayer,
Qwen3MoeForCausalLM,
Qwen3MoeMLP, Qwen3MoeModel,
Qwen3MoeSparseMoeBlock)
from vllm.model_executor.models.utils import (
PPMissingLayer, extract_layer_index,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm_ascend.ops.fused_moe import AscendFusedMoE
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
nn.Module.__init__(self)
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")
self.gate = ReplicatedLinear(
config.hidden_size,
config.num_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate",
)
self.experts = AscendFusedMoE(
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
)
self.top_k = config.num_experts_per_tok
self.dp_size = get_dp_group().world_size
self.tp_group = get_tp_group().device_group
self.tp_rank = get_tp_group().rank_in_group
self.ep_group = get_ep_group()
self.params_dtype = torch.get_default_dtype()
def forward(
self,
hidden_states,
attn_metadata=None,
):
if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata
# when profile runs, force experts to load balanced tokens
# to avoid high memory consumption on a single rank.
enable_force_load_balance = get_forward_context().in_profile_run
is_prefill = get_forward_context().with_prefill
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
is_prefill=is_prefill,
top_k=self.top_k,
enable_force_load_balance=enable_force_load_balance,
shared_experts=None,
)
return hidden_states
class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: Optional[VllmConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = Qwen3MoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
rms_norm_eps=config.rms_norm_eps,
qkv_bias=getattr(config, 'attention_bias', False),
head_dim=getattr(config, 'head_dim', None),
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
# `mlp_only_layers` in the config.
layer_idx = extract_layer_index(prefix)
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
config.mlp_only_layers)
self.use_aclgraph = (vllm_config is not None
and vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and not vllm_config.model_config.enforce_eager)
if (layer_idx not in mlp_only_layers) and (
config.num_experts > 0 and
(layer_idx + 1) % config.decoder_sparse_step == 0):
if not self.use_aclgraph:
# FIXME: custom sparse moe block doesn't work with aclgraph.
self.mlp = CustomSparseMoeBlock(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@support_torch_compile
class CustomQwen3MoeModel(Qwen3MoeModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.config = config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
prefix=f"{prefix}.embed_tokens")
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: CustomQwen3MoeDecoderLayer(
config=config,
cache_config=cache_config,
quant_config=quant_config,
vllm_config=vllm_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
SupportsPP.__init__(self)
SupportsLoRA.__init__(self)
MixtureOfExperts.__init__(self)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = CustomQwen3MoeModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"))
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
# Set MoE hyperparameters
self.expert_weights: list[torch.Tensor] = []
self.moe_layers: list[FusedMoE] = []
example_layer = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, Qwen3MoeDecoderLayer)
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
example_layer = layer.mlp
self.moe_layers.append(layer.mlp.experts)
if example_layer is None:
raise RuntimeError("No Qwen3MoE layer found in the model.layers.")
self.num_moe_layers = len(self.moe_layers)
self.num_expert_groups = 1
self.num_shared_experts = 0

View File

@@ -18,7 +18,6 @@
import torch
import vllm_ascend.ops.common_fused_moe # noqa
import vllm_ascend.ops.fused_moe # noqa
import vllm_ascend.ops.layernorm # noqa
import vllm_ascend.ops.register_custom_ops # noqa
import vllm_ascend.ops.vocab_parallel_embedding # noqa

View File

@@ -15,7 +15,7 @@
# limitations under the License.
#
import os.path
from typing import Callable, Optional
from typing import Any, Callable, Optional
import torch
import torch_npu
@@ -23,6 +23,7 @@ from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
tensor_model_parallel_all_reduce)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
@@ -37,99 +38,110 @@ from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
def unquantized_fused_moe_init_func(self, *args, **kwargs):
original_unquantized_fused_moe_init_func(self, *args, **kwargs)
def __init__(self, moe: FusedMoEConfig = None):
# NOTE: Currently, this self.use_aclgraph is only used in
# UnquantizedFusedMoEMethod.forward_oot to decide whether to use in
# ops/fused_moe.py:568 to circumvent torch.randint_like not supported issue.
# Once torch.randint_like is supported or removed, this flag can be removed.
vllm_config = get_current_vllm_config()
ascend_config = get_ascend_config()
if ascend_config.torchair_graph_config.enabled:
self.use_aclgraph = False
else:
self.use_aclgraph = (vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and not vllm_config.model_config.enforce_eager)
self.transpose = True
super().__init__(moe=moe)
# NOTE: Currently, this self.use_aclgraph is only used in
# UnquantizedFusedMoEMethod.forward_oot to decide whether to use in
# ops/fused_moe.py:568 to circumvent torch.randint_like not supported issue.
# Once torch.randint_like is supported or removed, this flag can be removed.
vllm_config = get_current_vllm_config()
ascend_config = get_ascend_config()
self.dynamic_eplb = get_ascend_config().dynamic_eplb
if ascend_config.torchair_graph_config.enabled:
self.use_aclgraph = False
else:
self.use_aclgraph = (vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and
not vllm_config.model_config.enforce_eager)
self.transpose = True
def forward_oot(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
def process_weights_after_loading(self, layer):
super(UnquantizedFusedMoEMethod,
self).process_weights_after_loading(layer)
if self.transpose:
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
1, 2).contiguous()
layer.w13_weight = torch.nn.Parameter(w13_data,
requires_grad=False)
topk_weights, topk_ids, row_idx = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
1, 2).contiguous()
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
moe_comm_method = get_forward_context().moe_comm_method
return moe_comm_method.fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
global_num_experts=global_num_experts,
expert_map=expert_map)
self.transpose = False
else:
w13_data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w13_weight = torch.nn.Parameter(w13_data,
requires_grad=False)
w2_data = self._maybe_pad_weight(layer.w2_weight.data)
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
def process_weights_after_loading(self, layer):
super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)
if self.transpose:
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
1, 2).contiguous()
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
if not is_310p():
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
1, 2).contiguous()
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
enable_force_load_balance: bool = False,
shared_experts: Optional[Any] = None,
**kwargs) -> torch.Tensor:
self.transpose = False
else:
w13_data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
topk_weights, topk_ids, row_idx = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
w2_data = self._maybe_pad_weight(layer.w2_weight.data)
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
topk_weights = topk_weights.to(x.dtype)
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance and not self.use_aclgraph:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
if not is_310p():
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
moe_comm_method = get_forward_context().moe_comm_method
return moe_comm_method.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
global_num_experts=global_num_experts,
expert_map=expert_map,
shared_experts=shared_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
dynamic_eplb=self.dynamic_eplb)
class AscendFusedMoE(FusedMoE):
@@ -138,8 +150,26 @@ class AscendFusedMoE(FusedMoE):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
num_experts = kwargs["num_experts"]
intermediate_size = kwargs["intermediate_size"]
AscendFusedMoE.moe_counter += 1
self.moe_instance_id = AscendFusedMoE.moe_counter
self.global_num_experts = num_experts
self.expert_map = None
self.log2phy = None
self.global_redundant_expert_num = 0
if self.quant_config is None:
self.quant_method = AscendUnquantizedFusedMoEMethod(
self.moe_config)
else:
self.quant_method = self.quant_config.get_quant_method(
self, self.layer_name)
assert self.quant_method is not None
self.moe_config.tp_group = get_tp_group()
self.moe_config.dp_group = get_dp_group()
self.moe_config.ep_group = get_ep_group()
@@ -148,6 +178,7 @@ class AscendFusedMoE(FusedMoE):
self.dynamic_eplb = ascend_config.dynamic_eplb
self.expert_map_path = ascend_config.expert_map_path
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
self.global_num_experts = num_experts + self.global_redundant_expert_num
# static eplb initializing with expert_map_path
if self.expert_map_path and os.path.exists(
self.expert_map_path) and os.access(self.expert_map_path,
@@ -180,6 +211,25 @@ class AscendFusedMoE(FusedMoE):
if self.dynamic_eplb:
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
self.moe_config.num_experts = self.global_num_experts
self.moe_config.num_local_experts = self.local_num_experts
self.moe_config.original_num_experts = num_experts
moe_quant_params = {
"num_experts": local_num_experts,
"hidden_size": self.hidden_size,
"intermediate_size_per_partition":
self.intermediate_size_per_partition,
"params_dtype": self.params_dtype,
"weight_loader": self.weight_loader,
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
setup_moe_comm_method(self.moe_config)
def update_expert_map(self, new_expert_map):
@@ -210,11 +260,18 @@ class AscendFusedMoE(FusedMoE):
router_logits: torch.Tensor):
assert self.quant_method is not None
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
quantized_x_for_share, dynamic_scale_for_share = None, None
forward_context = get_forward_context()
enable_force_load_balance = forward_context.in_profile_run
forward_context = get_forward_context()
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
hidden_states=hidden_states,
router_logits=router_logits,
replace_allreduce=forward_context.sp_enabled)
replace_allreduce=forward_context.sp_enabled,
enable_shared_expert_dp=self.enable_shared_expert_dp)
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
@@ -233,11 +290,13 @@ class AscendFusedMoE(FusedMoE):
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
)
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
shared_experts=None,
enable_force_load_balance=enable_force_load_balance,
log2phy=self.log2phy,
global_redundant_expert_num=self.global_redundant_expert_num)
if isinstance(final_hidden_states, tuple):
final_hidden_states, group_list_type, expert_tokens = final_hidden_states
@@ -361,8 +420,3 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
if self.multistream_overlap_shared_expert:
torch.npu.current_stream().wait_stream(self.shared_expert_stream)
return shared_out, fused_output
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
UnquantizedFusedMoEMethod.forward_oot = forward_oot

View File

@@ -1,455 +0,0 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/kernels/test_moe.py
import os
from typing import Any, Callable, Optional
import torch
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
get_tp_group)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import \
FusedMoEConfig # isort: skip
from vllm.model_executor.layers.fused_moe.config import \
FusedMoEParallelConfig # isort: skip
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
from vllm.model_executor.layers.quantization.base_config import \
QuantizationConfig
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
determine_default_log2phy_map)
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
get_all_reduce_merge_state,
get_rm_router_logits_state, is_310p)
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
def __init__(self, moe: FusedMoEConfig = None):
super().__init__(moe=moe)
vllm_config = get_current_vllm_config()
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
self.max_model_len = vllm_config.model_config.max_model_len
get_ascend_config()
self.dynamic_eplb = get_ascend_config().dynamic_eplb
try:
device_group = get_mc2_group().device_group
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu"))
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
local_rank)
except AttributeError:
self.moe_all_to_all_group_name = None
def process_weights_after_loading(self, layer):
super(UnquantizedFusedMoEMethod,
self).process_weights_after_loading(layer)
layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
layer.w13_weight.data),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
layer.w2_weight.data),
requires_grad=False)
if not is_310p():
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = False,
enable_force_load_balance: bool = False,
shared_experts: Optional[Any] = None,
**kwargs,
) -> torch.Tensor:
topk_weights, topk_ids, row_idx = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
topk_weights = topk_weights.to(x.dtype)
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance and not self.use_aclgraph:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
moe_comm_method = get_forward_context().moe_comm_method
return moe_comm_method.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
global_num_experts=global_num_experts,
expert_map=expert_map,
shared_experts=shared_experts,
need_trans=True,
dynamic_eplb=self.dynamic_eplb)
class AscendFusedMoE(FusedMoE):
# The moe_counter parameter is required during the initialization of EPLB
# to identify the current layer index within the MOE model.
moe_counter = -1
def __init__(
self,
num_experts: int, # Global number of experts
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
ep_size: Optional[int] = None,
dp_size: Optional[int] = None,
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
):
# TODO: This could not initialize FusedMoE baseclass,
# fixme and make __init__() of AscendFusedMoE more clear
super().__init__(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=reduce_results,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group,
quant_config=quant_config,
tp_size=tp_size,
ep_size=ep_size,
dp_size=dp_size,
prefix=prefix,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
)
AscendFusedMoE.moe_counter += 1
self.moe_instance_id = AscendFusedMoE.moe_counter
if params_dtype is None:
params_dtype = torch.get_default_dtype()
vllm_config = get_current_vllm_config()
self.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=(tp_size if tp_size is not None else
get_tensor_model_parallel_world_size()),
dp_size_=(dp_size
if dp_size is not None else get_dp_group().world_size),
vllm_parallel_config=vllm_config.parallel_config)
self.top_k = top_k
self.num_experts = num_experts
self.global_num_experts = num_experts
assert intermediate_size % self.tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
self.expert_map = None
self.activation = activation
self.log2phy = None
self.global_redundant_expert_num = 0
is_deepseek_v3_r1 = self.global_num_experts == 256
self.rm_router_logits = get_rm_router_logits_state(
self.moe_parallel_config.ep_size, self.dp_size, is_deepseek_v3_r1)
self.all_reduce_merge = get_all_reduce_merge_state(
self.moe_parallel_config.ep_size, is_deepseek_v3_r1)
ascend_config = get_ascend_config()
self.dynamic_eplb = ascend_config.dynamic_eplb
self.expert_map_path = ascend_config.expert_map_path
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
self.global_num_experts = num_experts + self.global_redundant_expert_num
# static eplb initializing with expert_map_path
if self.expert_map_path and os.path.exists(
self.expert_map_path) and os.access(self.expert_map_path,
os.R_OK):
self.expert_load_balancer = ExpertLoadBalancer(
self.expert_map_path, self.global_num_experts)
self.local_num_experts, self.expert_map = (
self.expert_load_balancer.get_rank_placement_map(
self.moe_instance_id, self.ep_rank))
self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(
self.moe_instance_id, self.ep_rank).npu()
self.global_redundant_expert_num = (
self.expert_load_balancer.get_global_redundant_expert_num())
else:
# init moe.
self.local_num_experts, self.expert_map = determine_expert_map(
self.ep_size, self.ep_rank, self.global_num_experts)
# dynamic eplb initializing with not expert_map_path
if self.dynamic_eplb:
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
self.local_num_experts, self.expert_map = determine_default_expert_map(
self.global_num_experts, self.ep_size, self.ep_rank,
self.global_redundant_expert_num)
self.log2phy = determine_default_log2phy_map(
self.global_num_experts, self.ep_size, self.ep_rank,
self.global_redundant_expert_num)
local_num_experts = (torch.sum(self.expert_map != -1)
if self.expert_map is not None else num_experts)
if self.dynamic_eplb:
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.")
moe = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=params_dtype,
)
self.moe_config = moe
# TODO: The self.moe_config.tp_size here is not correct, fixme soon
if quant_config is None:
self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None
local_num_experts = torch.sum(self.expert_map != -1) \
if self.expert_map is not None else num_experts
self.moe_load = None
if self.dynamic_eplb:
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
moe_quant_params = {
"num_experts": local_num_experts,
"hidden_size": hidden_size,
"intermediate_size_per_partition":
self.intermediate_size_per_partition,
"params_dtype": params_dtype,
"weight_loader": self.weight_loader,
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.ep_group = get_ep_group()
# NOTE: self.tp_group is not expert_tp_group
self.tp_group = get_tp_group().device_group
self.quant_method.create_weights(layer=self, **moe_quant_params)
self.moe_config.tp_group = get_tp_group()
self.moe_config.dp_group = get_dp_group()
self.moe_config.ep_group = get_ep_group()
self.moe_config.mc2_group = get_mc2_group()
self.moe_config.num_global_redundant_experts = self.global_redundant_expert_num
setup_moe_comm_method(self.moe_config)
def update_expert_map(self, new_expert_map):
self.expert_map = new_expert_map
def get_map(self):
return self.expert_map
def get_log2phy_map(self):
return self.logical_to_physical_map
def clear_moe_load(self):
if self.moe_load is not None:
self.moe_load.zero_()
def forward(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_prefill: bool,
enable_force_load_balance: bool = False,
top_k: Optional[int] = None,
shared_experts: Optional[Any] = None,
gate=None,
replace_allreduce: bool = False):
assert self.quant_method is not None
if top_k:
real_top_k = top_k
else:
real_top_k = self.top_k
forward_context = get_forward_context()
mc2_mask = forward_context.mc2_mask
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
quantized_x_for_share, dynamic_scale_for_share = None, None
if shared_experts:
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
shared_hidden_states = shared_experts(hidden_states)
if forward_context.sp_enabled:
replace_allreduce = True
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
hidden_states=hidden_states,
router_logits=router_logits,
enable_shared_expert_dp=self.enable_shared_expert_dp,
rm_router_logits=self.rm_router_logits,
replace_allreduce=replace_allreduce,
gate=gate)
# Matrix multiply.
e_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=real_top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
is_prefill=is_prefill,
enable_force_load_balance=enable_force_load_balance,
log2phy=self.log2phy,
global_redundant_expert_num=self.global_redundant_expert_num,
shared_experts=None,
mc2_mask=mc2_mask,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
)
group_list_type = None
if shared_experts:
if isinstance(e_hidden_states,
tuple) and len(e_hidden_states) == 2:
e_hidden_states, shared_hidden_states = e_hidden_states
if isinstance(e_hidden_states, tuple) and len(e_hidden_states) == 3:
e_hidden_states, group_list_type, expert_tokens = e_hidden_states
if self.dynamic_eplb and group_list_type is not None:
self.moe_load += expert_tokens if group_list_type else \
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
final_hidden_states = forward_context.moe_comm_method.finalize(
hidden_states=e_hidden_states,
reduce_results=(not self.all_reduce_merge))
if shared_experts:
return final_hidden_states, shared_hidden_states
else:
return final_hidden_states
# ----------------------------------------- TBO-related --------------------------------------------
def _forward_ms_fused_moe_comp(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_prefill: bool,
real_top_k,
enable_force_load_balance: bool = False,
):
hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=real_top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
is_prefill=is_prefill,
enable_force_load_balance=enable_force_load_balance,
)
return hidden_states

View File

@@ -26,6 +26,8 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.utils import get_rm_router_logits_state
class FusedMoEPrepareAndFinalize(ABC):
"""
@@ -41,13 +43,16 @@ class FusedMoEPrepareAndFinalize(ABC):
def __init__(self, moe_config: FusedMoEConfig):
self.moe_config = moe_config
is_deepseek_v3_r1 = self.moe_config.original_num_experts == 256
self.rm_router_logits = get_rm_router_logits_state(
self.moe_config.ep_size, self.moe_config.dp_size,
is_deepseek_v3_r1)
@abstractmethod
def prepare(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
rm_router_logits: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
@@ -61,7 +66,6 @@ class FusedMoEPrepareAndFinalize(ABC):
hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size]
router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
enable_shared_expert_dp (bool): Skip DP communication for shared experts
rm_router_logits (bool): Discard input router_logits and recompute via gate
replace_allreduce (bool): Bypass default all-reduce behavior
gate (nn.Module, optional): Gate network to recompute router_logits if needed
@@ -116,7 +120,6 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
rm_router_logits: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
@@ -215,7 +218,6 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
rm_router_logits: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
@@ -294,7 +296,6 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
rm_router_logits: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
@@ -302,7 +303,6 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
1. Fetch max token count across DP group from forward context.
2. Pad local tensors to that size.
3. All-gather across DP group to form global input tensor.
4. Optionally recompute router_logits using gate if `rm_router_logits=True`.
Returns:
Tuple of (global_hidden_states, global_router_logits, None)
@@ -318,14 +318,14 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
if not rm_router_logits:
if not self.rm_router_logits:
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
# All-gather across DP group
hidden_states = self.moe_config.dp_group.all_gather(
hidden_states, 0)
if rm_router_logits:
if self.rm_router_logits:
router_logits, _ = gate(hidden_states) # Recompute globally
else:
router_logits = self.moe_config.dp_group.all_gather(
@@ -399,14 +399,12 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
rm_router_logits: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Preparation steps:
1. Fetch cumulative token boundaries from forward context.
2. Multicast hidden_states and router_logits to form global tensors.
3. Optionally recompute router_logits globally if `rm_router_logits=True`.
Returns:
Tuple of (global_hidden_states, global_router_logits, None)
@@ -418,7 +416,7 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
).dp_metadata.cu_tokens_across_sp(1)
hidden_states = self._naive_multicast(hidden_states,
self.cu_tokens_across_dp_cpu)
if rm_router_logits:
if self.rm_router_logits:
router_logits, _ = gate(hidden_states)
else:
router_logits = self._naive_multicast(

View File

@@ -67,12 +67,11 @@ class MoECommMethod(ABC):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
rm_router_logits: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare(
hidden_states, router_logits, enable_shared_expert_dp,
rm_router_logits, replace_allreduce, gate)
replace_allreduce, gate)
self.mc2_mask = mc2_mask
return hidden_states, router_logits

View File

@@ -468,9 +468,6 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
super().__init__(**kwargs)
self.with_quant = False
self.num_local_experts = kwargs.get("num_local_experts", 0)
self.num_global_redundant_experts = kwargs.get(
"num_global_redundant_experts", 0)
self.num_experts = self.num_experts + self.num_global_redundant_experts
self.hidden_shape = None
self.topk_weights = None

View File

@@ -37,7 +37,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
get_otp_group)
from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
oproj_tp_enable)

View File

@@ -53,9 +53,9 @@ from vllm.sequence import IntermediateTensors
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.torchair.ops.sequence_parallel import (MetadataForPadding,
init_metadata_for_sp)
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -81,7 +81,7 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
prefix=f"{prefix}.gate",
)
self.experts = AscendFusedMoE(
self.experts = TorchairAscendFusedMoE(
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,

View File

@@ -880,7 +880,7 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance and not self.use_aclgraph:
if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
fused_moe_state = get_forward_context().fused_moe_state