[MoE] [Refactor] Combine common_fused_moe and fused_moe (#3176)

### What this PR does / why we need it?
1. Move additional functionalities from fused_moe.py to
common_fused_moe.py and remove fused_moe.py
2. Remove unnecessary custom classes from qwen3_moe.py, and it will be
completely removed after we release vllm-ascend v0.11.0

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

Qwen3-30B-A3B/Qwen3-30B-A3B-W8A8/DeepSeek-V3-W4A8-Pruing/deepseek-mtp/pangu-pro-moe-pruing:

1. Enable/Disable EP
3. Aclgraph & eager
4. SP


- vLLM version: v0.11.0

---------

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
weichen
2025-10-09 14:12:46 +08:00
committed by GitHub
parent a36e3da78e
commit 94dd832815
17 changed files with 175 additions and 1110 deletions

View File

@@ -96,7 +96,7 @@ def mock_distributed():
patch("vllm_ascend.models.deepseek_v2.get_pp_group", return_value=pp_group), \ patch("vllm_ascend.models.deepseek_v2.get_pp_group", return_value=pp_group), \
patch("vllm_ascend.models.deepseek_v2.get_pp_group", patch("vllm_ascend.models.deepseek_v2.get_pp_group",
return_value=Mock(is_first_rank=False, is_last_rank=False)), \ return_value=Mock(is_first_rank=False, is_last_rank=False)), \
patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \ patch("vllm_ascend.ops.common_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
patch("vllm_ascend.ops.moe.token_dispatcher.torch.distributed.get_rank", return_value=0), \ patch("vllm_ascend.ops.moe.token_dispatcher.torch.distributed.get_rank", return_value=0), \
patch("vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version", return_value=None), \ patch("vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version", return_value=None), \
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group, patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,

View File

@@ -1,68 +0,0 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import math
import unittest
import torch
from vllm_ascend.torchair.models.qwen3_moe import CustomQwen3MoeAttention
class DummyRMSNorm:
def __init__(self, dim: int, eps: float = 1e-6):
self.dim = dim
self.eps = eps
def __call__(self, x):
mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
denom = (mean_sq + self.eps).sqrt()
return x / denom
class TestCustomQwen3MoeAttention(unittest.TestCase):
def setUp(self):
self.batch = 2
self.seq_len = 3
self.q_size = 8
self.kv_size = 8
self.head_dim = 4
self.rms_eps = 1e-6
total_dim = self.q_size + 2 * self.kv_size
self.qkv = torch.arange(self.batch * self.seq_len * total_dim,
dtype=torch.float32).reshape(
self.batch, self.seq_len, total_dim)
def test_constant_input_normalization(self):
ones_qkv = torch.ones((1, 1, self.q_size + 2 * self.kv_size),
dtype=torch.float32)
q_norm = DummyRMSNorm(self.head_dim, self.rms_eps)
k_norm = DummyRMSNorm(self.head_dim, self.rms_eps)
q, k, v = CustomQwen3MoeAttention.normalize_qkv(
ones_qkv, self.q_size, self.kv_size, self.head_dim, q_norm, k_norm)
norm_val = 1.0 / math.sqrt(1.0 + self.rms_eps)
expected_q = torch.full((1, 1, self.q_size), norm_val)
expected_k = torch.full((1, 1, self.kv_size), norm_val)
expected_v = torch.ones((1, 1, self.kv_size), dtype=torch.float32)
self.assertTrue(torch.allclose(q, expected_q, atol=1e-6))
self.assertTrue(torch.allclose(k, expected_k, atol=1e-6))
self.assertTrue(torch.equal(v, expected_v))

View File

@@ -21,6 +21,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
self.moe_config.tp_size = 1 self.moe_config.tp_size = 1
self.moe_config.ep_size = 1 self.moe_config.ep_size = 1
self.moe_config.dp_group = MagicMock() self.moe_config.dp_group = MagicMock()
self.moe_config.original_num_experts = 8
@patch( @patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size", "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
@@ -196,7 +197,6 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
h_out, r_out, _ = layer.prepare(hidden_states, h_out, r_out, _ = layer.prepare(hidden_states,
router_logits, router_logits,
rm_router_logits=False,
gate=mock_gate) gate=mock_gate)
# After all-gather with DP=2, should double the batch size # After all-gather with DP=2, should double the batch size
@@ -265,7 +265,6 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
# Run prepare # Run prepare
h_out, r_out, _ = layer.prepare(hidden_states, h_out, r_out, _ = layer.prepare(hidden_states,
router_logits, router_logits,
rm_router_logits=False,
gate=mock_gate) gate=mock_gate)
# Should be global tensor: [7, 8] and [7, 2] # Should be global tensor: [7, 8] and [7, 2]

View File

@@ -24,8 +24,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
from tests.ut.base import TestBase from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.fused_moe import (AscendFusedMoE, from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
AscendUnquantizedFusedMoEMethod)
from vllm_ascend.ops.moe.experts_selector import select_experts from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp
from vllm_ascend.utils import AscendSocVersion, adapt_patch from vllm_ascend.utils import AscendSocVersion, adapt_patch
@@ -70,7 +69,7 @@ def setup_vllm_config_mock(mocker: MockerFixture):
mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4) mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
mock_vllm_config.model_config.max_model_len = 2048 mock_vllm_config.model_config.max_model_len = 2048
mocker.patch('vllm_ascend.ops.fused_moe.get_current_vllm_config', mocker.patch('vllm_ascend.ops.common_fused_moe.get_current_vllm_config',
return_value=mock_vllm_config) return_value=mock_vllm_config)
mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config', mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
return_value=mock_vllm_config) return_value=mock_vllm_config)
@@ -104,24 +103,24 @@ def mock_dist_env(mocker: MockerFixture):
with patch('torch.distributed.get_rank', return_value=0), \ with patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \ patch('torch.distributed.get_world_size', return_value=4), \
patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \ patch('vllm_ascend.ops.common_fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \ patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \ patch('vllm_ascend.ops.common_fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm_ascend.ops.common_fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm_ascend.ops.common_fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group', patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
return_value=mock_dp_and_tp_group(mocker)), \ return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_ascend_config', patch('vllm_ascend.ops.common_fused_moe.get_ascend_config',
return_value=MagicMock( return_value=MagicMock(
torchair_graph_config=MagicMock(enabled=False), torchair_graph_config=MagicMock(enabled=False),
enable_multistream_moe=False, enable_multistream_moe=False,
expert_map_path=None expert_map_path=None
)), \ )), \
patch('vllm_ascend.ops.fused_moe.determine_expert_map', patch('vllm_ascend.ops.common_fused_moe.determine_expert_map',
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \ return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
patch('vllm_ascend.ops.fused_moe.get_forward_context', patch('vllm_ascend.ops.common_fused_moe.get_forward_context',
return_value=mock_forward_context_obj), \ return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context', patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
return_value=mock_forward_context_obj), \ return_value=mock_forward_context_obj), \
@@ -252,196 +251,6 @@ class MockFusedMoEMethod(FusedMoEMethodBase):
pass pass
class TestAscendFusedMoe:
def test_init_no_quant(self, mock_dist_env, default_moe_config):
layer = AscendFusedMoE(**default_moe_config)
layer.w13_weight = nn.Parameter(
torch.randn(default_moe_config['num_experts'],
default_moe_config['intermediate_size'] * 2,
default_moe_config['hidden_size']))
layer.w2_weight = nn.Parameter(
torch.randn(default_moe_config['num_experts'],
default_moe_config['hidden_size'],
default_moe_config['intermediate_size']))
assert layer.num_experts == default_moe_config['num_experts']
assert layer.top_k == default_moe_config['top_k']
assert hasattr(layer, 'w13_weight')
assert hasattr(layer, 'w2_weight')
with pytest.raises(AssertionError):
error_config = default_moe_config.copy()
error_config['use_grouped_topk'] = True
layer = AscendFusedMoE(**error_config)
with pytest.raises(ValueError):
error_config = default_moe_config.copy()
error_config['scoring_func'] = "random"
layer = AscendFusedMoE(**error_config)
def test_init_with_quant(self, mock_dist_env, default_moe_config):
mock_quant_config = MagicMock()
mock_quant_method = MockFusedMoEMethod()
mock_quant_config.get_quant_method.return_value = mock_quant_method
moe = AscendFusedMoE(**default_moe_config,
quant_config=mock_quant_config)
assert moe.quant_method is not None
assert moe.quant_method == mock_quant_method
@pytest.mark.parametrize(
"others_param",
[[None,
MagicMock(return_value=torch.randn(5, 32)), False, 5, None],
[2, None, False, 5, None], [None, None, True, 5, None],
[None, None, False, 1, None], [None, None, True, 5, 1],
[None, None, False, 5, 1]])
def test_forward(self, mock_dist_env, default_moe_config, others_param):
top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param
inputs = torch.randn(num_tokens, 32)
router_logits = torch.randn(num_tokens, 8)
moe = AscendFusedMoE(**default_moe_config)
if ep_size == 1:
moe.moe_parallel_config.ep_size = 1
moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
forward_context = mock_dist_env['mock_forward_context_obj']
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
return_value=forward_context):
output = moe.forward(inputs,
router_logits,
is_prefill=is_prefill,
top_k=top_k,
shared_experts=shared_experts)
moe.quant_method.apply.assert_called_once()
if shared_experts:
assert output[0].shape == (num_tokens, 32)
assert output[1].shape == (num_tokens, 10)
else:
assert output.shape == (num_tokens, 32)
def test_forward_ms_fused_moe_comp(self, mock_dist_env,
default_moe_config):
inputs = torch.randn(5, 32)
router_logits = torch.randn(5, 8)
moe = AscendFusedMoE(**default_moe_config)
moe.quant_method = MockQuantMethod(None, 5)
output = moe._forward_ms_fused_moe_comp(inputs,
router_logits,
is_prefill=False,
real_top_k=1)
moe.quant_method.apply.assert_called_once()
assert output.shape == (5, 32)
class TestAscendUnquantizedFusedMoEMethod:
def test_process_weights_after_loading(self, moe_method, mock_dist_env):
layer = MagicMock()
layer.w13_weight.data = torch.randn(16, 32)
layer.w2_weight.data = torch.randn(16, 32)
with patch('torch_npu.npu_format_cast', mock_npu_format_cast), \
patch('vllm_ascend.utils.is_310p', return_value=False):
moe_method.process_weights_after_loading(layer)
assert isinstance(layer.w13_weight, torch.nn.Parameter)
assert isinstance(layer.w2_weight, torch.nn.Parameter)
assert not layer.w13_weight.requires_grad
assert not layer.w2_weight.requires_grad
@pytest.mark.parametrize("others_param",
[[256, 4], [128, 1], [128, 1], [128, 4]])
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
global_num_experts, ep_size = others_param
is_prefill = False
forward_context = mock_dist_env['mock_forward_context_obj']
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
return_value=forward_context):
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
router_logits = torch.randn(8, 8)
layer = MagicMock()
local_num_experts = 2
hidden_size = 2
intermediate_size_per_partition = 4
layer.w13_weight = torch.randn(local_num_experts,
intermediate_size_per_partition * 2,
hidden_size)
layer.w2_weight = torch.randn(local_num_experts, hidden_size,
intermediate_size_per_partition)
result = moe_method.apply(layer=layer,
x=x,
router_logits=router_logits,
top_k=2,
renormalize=True,
global_num_experts=global_num_experts,
is_prefill=is_prefill)
mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
mock_moe_comm_method.fused_experts.assert_called_once()
expected_shape = (16, 2)
assert result.shape == expected_shape
@pytest.mark.parametrize("others_param", [16, 1, 4])
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
ep_size = others_param
is_prefill = False
forward_context = mock_dist_env['mock_forward_context_obj']
with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
if ep_size == 1:
x = x.view(-1, 2)
router_logits = torch.randn(8, 8)
layer = MagicMock()
local_num_experts = 2
hidden_size = 2
intermediate_size_per_partition = 4
layer.w13_weight = torch.randn(local_num_experts,
intermediate_size_per_partition * 2,
hidden_size)
layer.w2_weight = torch.randn(local_num_experts, hidden_size,
intermediate_size_per_partition)
result = moe_method.apply(layer=layer,
x=x,
router_logits=router_logits,
top_k=2,
renormalize=True,
global_num_experts=128,
expert_map=expert_map,
is_prefill=is_prefill)
mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
mock_moe_comm_method.fused_experts.assert_called_once()
expected_shape = (16, 2)
assert result.shape == expected_shape
class TestExpertsSelector: class TestExpertsSelector:
@pytest.mark.parametrize("global_num_experts", [[256], [128]]) @pytest.mark.parametrize("global_num_experts", [[256], [128]])

View File

@@ -63,7 +63,7 @@ class TestMoECommMethod(TestBase):
# Verify prepare was called with correct arguments # Verify prepare was called with correct arguments
mock_pf_instance.prepare.assert_called_once_with( mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, False, None) hidden_states, router_logits, False, False, None)
# Test finalize method # Test finalize method
comm_impl.finalize(h_out, reduce_results=True) comm_impl.finalize(h_out, reduce_results=True)
@@ -108,7 +108,7 @@ class TestMoECommMethod(TestBase):
# Verify prepare was called with correct arguments # Verify prepare was called with correct arguments
mock_pf_instance.prepare.assert_called_once_with( mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, False, None) hidden_states, router_logits, False, False, None)
# Test finalize method # Test finalize method
comm_impl.finalize(h_out, reduce_results=True) comm_impl.finalize(h_out, reduce_results=True)
@@ -153,7 +153,7 @@ class TestMoECommMethod(TestBase):
# Verify prepare was called with correct arguments # Verify prepare was called with correct arguments
mock_pf_instance.prepare.assert_called_once_with( mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, False, None) hidden_states, router_logits, False, False, None)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")

View File

@@ -45,10 +45,6 @@ def register_model():
"DeepSeekMTPModel", "DeepSeekMTPModel",
"vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP") "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
ModelRegistry.register_model(
"Qwen3MoeForCausalLM",
"vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
# There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization # There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
# to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM. # to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM.
ModelRegistry.register_model( ModelRegistry.register_model(

View File

@@ -62,7 +62,7 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.mla import AscendMLAModules from vllm_ascend.models.layers.mla import AscendMLAModules
from vllm_ascend.models.layers.sfa import (AscendSFAModules, from vllm_ascend.models.layers.sfa import (AscendSFAModules,
AscendSparseFlashAttention, Indexer) AscendSparseFlashAttention, Indexer)
from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
class CustomDeepseekV2RowParallelLinear(RowParallelLinear): class CustomDeepseekV2RowParallelLinear(RowParallelLinear):

View File

@@ -1,263 +0,0 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from vllm/model_executor/models/qwen3_moe.py
# This file is a part of the vllm-ascend project.
from typing import Optional
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, CompilationLevel, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
get_tp_group)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.interfaces import (MixtureOfExperts,
SupportsLoRA, SupportsPP)
from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
Qwen3MoeDecoderLayer,
Qwen3MoeForCausalLM,
Qwen3MoeMLP, Qwen3MoeModel,
Qwen3MoeSparseMoeBlock)
from vllm.model_executor.models.utils import (
PPMissingLayer, extract_layer_index,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm_ascend.ops.fused_moe import AscendFusedMoE
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
nn.Module.__init__(self)
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")
self.gate = ReplicatedLinear(
config.hidden_size,
config.num_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate",
)
self.experts = AscendFusedMoE(
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
)
self.top_k = config.num_experts_per_tok
self.dp_size = get_dp_group().world_size
self.tp_group = get_tp_group().device_group
self.tp_rank = get_tp_group().rank_in_group
self.ep_group = get_ep_group()
self.params_dtype = torch.get_default_dtype()
def forward(
self,
hidden_states,
attn_metadata=None,
):
if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata
# when profile runs, force experts to load balanced tokens
# to avoid high memory consumption on a single rank.
enable_force_load_balance = get_forward_context().in_profile_run
is_prefill = get_forward_context().with_prefill
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
is_prefill=is_prefill,
top_k=self.top_k,
enable_force_load_balance=enable_force_load_balance,
shared_experts=None,
)
return hidden_states
class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: Optional[VllmConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = Qwen3MoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
rms_norm_eps=config.rms_norm_eps,
qkv_bias=getattr(config, 'attention_bias', False),
head_dim=getattr(config, 'head_dim', None),
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
# `mlp_only_layers` in the config.
layer_idx = extract_layer_index(prefix)
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
config.mlp_only_layers)
self.use_aclgraph = (vllm_config is not None
and vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and not vllm_config.model_config.enforce_eager)
if (layer_idx not in mlp_only_layers) and (
config.num_experts > 0 and
(layer_idx + 1) % config.decoder_sparse_step == 0):
if not self.use_aclgraph:
# FIXME: custom sparse moe block doesn't work with aclgraph.
self.mlp = CustomSparseMoeBlock(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@support_torch_compile
class CustomQwen3MoeModel(Qwen3MoeModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.config = config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
prefix=f"{prefix}.embed_tokens")
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: CustomQwen3MoeDecoderLayer(
config=config,
cache_config=cache_config,
quant_config=quant_config,
vllm_config=vllm_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
SupportsPP.__init__(self)
SupportsLoRA.__init__(self)
MixtureOfExperts.__init__(self)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = CustomQwen3MoeModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"))
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
# Set MoE hyperparameters
self.expert_weights: list[torch.Tensor] = []
self.moe_layers: list[FusedMoE] = []
example_layer = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, Qwen3MoeDecoderLayer)
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
example_layer = layer.mlp
self.moe_layers.append(layer.mlp.experts)
if example_layer is None:
raise RuntimeError("No Qwen3MoE layer found in the model.layers.")
self.num_moe_layers = len(self.moe_layers)
self.num_expert_groups = 1
self.num_shared_experts = 0

View File

@@ -18,7 +18,6 @@
import torch import torch
import vllm_ascend.ops.common_fused_moe # noqa import vllm_ascend.ops.common_fused_moe # noqa
import vllm_ascend.ops.fused_moe # noqa
import vllm_ascend.ops.layernorm # noqa import vllm_ascend.ops.layernorm # noqa
import vllm_ascend.ops.register_custom_ops # noqa import vllm_ascend.ops.register_custom_ops # noqa
import vllm_ascend.ops.vocab_parallel_embedding # noqa import vllm_ascend.ops.vocab_parallel_embedding # noqa

View File

@@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
# #
import os.path import os.path
from typing import Callable, Optional from typing import Any, Callable, Optional
import torch import torch
import torch_npu import torch_npu
@@ -23,6 +23,7 @@ from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group, from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map) FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
@@ -37,99 +38,110 @@ from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
def unquantized_fused_moe_init_func(self, *args, **kwargs): def __init__(self, moe: FusedMoEConfig = None):
original_unquantized_fused_moe_init_func(self, *args, **kwargs)
# NOTE: Currently, this self.use_aclgraph is only used in super().__init__(moe=moe)
# UnquantizedFusedMoEMethod.forward_oot to decide whether to use in
# ops/fused_moe.py:568 to circumvent torch.randint_like not supported issue.
# Once torch.randint_like is supported or removed, this flag can be removed.
vllm_config = get_current_vllm_config()
ascend_config = get_ascend_config()
if ascend_config.torchair_graph_config.enabled:
self.use_aclgraph = False
else:
self.use_aclgraph = (vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and not vllm_config.model_config.enforce_eager)
self.transpose = True
# NOTE: Currently, this self.use_aclgraph is only used in
# UnquantizedFusedMoEMethod.forward_oot to decide whether to use in
# ops/fused_moe.py:568 to circumvent torch.randint_like not supported issue.
# Once torch.randint_like is supported or removed, this flag can be removed.
vllm_config = get_current_vllm_config()
ascend_config = get_ascend_config()
self.dynamic_eplb = get_ascend_config().dynamic_eplb
if ascend_config.torchair_graph_config.enabled:
self.use_aclgraph = False
else:
self.use_aclgraph = (vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and
not vllm_config.model_config.enforce_eager)
self.transpose = True
def forward_oot( def process_weights_after_loading(self, layer):
self, super(UnquantizedFusedMoEMethod,
layer: torch.nn.Module, self).process_weights_after_loading(layer)
x: torch.Tensor, if self.transpose:
use_grouped_topk: bool, w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
top_k: int, 1, 2).contiguous()
router_logits: torch.Tensor, layer.w13_weight = torch.nn.Parameter(w13_data,
renormalize: bool, requires_grad=False)
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
topk_weights, topk_ids, row_idx = select_experts( w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
hidden_states=x, 1, 2).contiguous()
router_logits=router_logits, layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
moe_comm_method = get_forward_context().moe_comm_method self.transpose = False
return moe_comm_method.fused_experts(hidden_states=x, else:
w1=layer.w13_weight, w13_data = self._maybe_pad_weight(layer.w13_weight.data)
w2=layer.w2_weight, layer.w13_weight = torch.nn.Parameter(w13_data,
topk_weights=topk_weights, requires_grad=False)
topk_ids=topk_ids,
row_idx=row_idx,
global_num_experts=global_num_experts,
expert_map=expert_map)
w2_data = self._maybe_pad_weight(layer.w2_weight.data)
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
def process_weights_after_loading(self, layer): if not is_310p():
super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer) layer.w13_weight.data = torch_npu.npu_format_cast(
if self.transpose: layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose( layer.w2_weight.data = torch_npu.npu_format_cast(
1, 2).contiguous() layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose( def apply(self,
1, 2).contiguous() layer: torch.nn.Module,
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False) x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
enable_force_load_balance: bool = False,
shared_experts: Optional[Any] = None,
**kwargs) -> torch.Tensor:
self.transpose = False topk_weights, topk_ids, row_idx = select_experts(
else: hidden_states=x,
w13_data = self._maybe_pad_weight(layer.w13_weight.data) router_logits=router_logits,
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False) top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
w2_data = self._maybe_pad_weight(layer.w2_weight.data) topk_weights = topk_weights.to(x.dtype)
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False) # this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance and not self.use_aclgraph:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
if not is_310p(): moe_comm_method = get_forward_context().moe_comm_method
layer.w13_weight.data = torch_npu.npu_format_cast( return moe_comm_method.fused_experts(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ) hidden_states=x,
layer.w2_weight.data = torch_npu.npu_format_cast( w1=layer.w13_weight,
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ) w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
global_num_experts=global_num_experts,
expert_map=expert_map,
shared_experts=shared_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
dynamic_eplb=self.dynamic_eplb)
class AscendFusedMoE(FusedMoE): class AscendFusedMoE(FusedMoE):
@@ -138,8 +150,26 @@ class AscendFusedMoE(FusedMoE):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
num_experts = kwargs["num_experts"]
intermediate_size = kwargs["intermediate_size"]
AscendFusedMoE.moe_counter += 1 AscendFusedMoE.moe_counter += 1
self.moe_instance_id = AscendFusedMoE.moe_counter self.moe_instance_id = AscendFusedMoE.moe_counter
self.global_num_experts = num_experts
self.expert_map = None
self.log2phy = None
self.global_redundant_expert_num = 0
if self.quant_config is None:
self.quant_method = AscendUnquantizedFusedMoEMethod(
self.moe_config)
else:
self.quant_method = self.quant_config.get_quant_method(
self, self.layer_name)
assert self.quant_method is not None
self.moe_config.tp_group = get_tp_group() self.moe_config.tp_group = get_tp_group()
self.moe_config.dp_group = get_dp_group() self.moe_config.dp_group = get_dp_group()
self.moe_config.ep_group = get_ep_group() self.moe_config.ep_group = get_ep_group()
@@ -148,6 +178,7 @@ class AscendFusedMoE(FusedMoE):
self.dynamic_eplb = ascend_config.dynamic_eplb self.dynamic_eplb = ascend_config.dynamic_eplb
self.expert_map_path = ascend_config.expert_map_path self.expert_map_path = ascend_config.expert_map_path
self.global_redundant_expert_num = ascend_config.init_redundancy_expert self.global_redundant_expert_num = ascend_config.init_redundancy_expert
self.global_num_experts = num_experts + self.global_redundant_expert_num
# static eplb initializing with expert_map_path # static eplb initializing with expert_map_path
if self.expert_map_path and os.path.exists( if self.expert_map_path and os.path.exists(
self.expert_map_path) and os.access(self.expert_map_path, self.expert_map_path) and os.access(self.expert_map_path,
@@ -180,6 +211,25 @@ class AscendFusedMoE(FusedMoE):
if self.dynamic_eplb: if self.dynamic_eplb:
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64) self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
self.moe_config.num_experts = self.global_num_experts
self.moe_config.num_local_experts = self.local_num_experts
self.moe_config.original_num_experts = num_experts
moe_quant_params = {
"num_experts": local_num_experts,
"hidden_size": self.hidden_size,
"intermediate_size_per_partition":
self.intermediate_size_per_partition,
"params_dtype": self.params_dtype,
"weight_loader": self.weight_loader,
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
setup_moe_comm_method(self.moe_config) setup_moe_comm_method(self.moe_config)
def update_expert_map(self, new_expert_map): def update_expert_map(self, new_expert_map):
@@ -210,11 +260,18 @@ class AscendFusedMoE(FusedMoE):
router_logits: torch.Tensor): router_logits: torch.Tensor):
assert self.quant_method is not None assert self.quant_method is not None
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
quantized_x_for_share, dynamic_scale_for_share = None, None
forward_context = get_forward_context()
enable_force_load_balance = forward_context.in_profile_run
forward_context = get_forward_context() forward_context = get_forward_context()
hidden_states, router_logits = forward_context.moe_comm_method.prepare( hidden_states, router_logits = forward_context.moe_comm_method.prepare(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
replace_allreduce=forward_context.sp_enabled) replace_allreduce=forward_context.sp_enabled,
enable_shared_expert_dp=self.enable_shared_expert_dp)
# Matrix multiply. # Matrix multiply.
final_hidden_states = self.quant_method.apply( final_hidden_states = self.quant_method.apply(
@@ -233,11 +290,13 @@ class AscendFusedMoE(FusedMoE):
e_score_correction_bias=self.e_score_correction_bias, e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation, activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input, apply_router_weight_on_input=self.apply_router_weight_on_input,
enable_eplb=self.enable_eplb, quantized_x_for_share=quantized_x_for_share,
expert_load_view=self.expert_load_view, dynamic_scale_for_share=dynamic_scale_for_share,
logical_to_physical_map=self.logical_to_physical_map, shared_experts=None,
logical_replica_count=self.logical_replica_count, enable_force_load_balance=enable_force_load_balance,
) log2phy=self.log2phy,
global_redundant_expert_num=self.global_redundant_expert_num)
if isinstance(final_hidden_states, tuple): if isinstance(final_hidden_states, tuple):
final_hidden_states, group_list_type, expert_tokens = final_hidden_states final_hidden_states, group_list_type, expert_tokens = final_hidden_states
@@ -361,8 +420,3 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
if self.multistream_overlap_shared_expert: if self.multistream_overlap_shared_expert:
torch.npu.current_stream().wait_stream(self.shared_expert_stream) torch.npu.current_stream().wait_stream(self.shared_expert_stream)
return shared_out, fused_output return shared_out, fused_output
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
UnquantizedFusedMoEMethod.forward_oot = forward_oot

View File

@@ -1,455 +0,0 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/kernels/test_moe.py
import os
from typing import Any, Callable, Optional
import torch
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
get_tp_group)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import \
FusedMoEConfig # isort: skip
from vllm.model_executor.layers.fused_moe.config import \
FusedMoEParallelConfig # isort: skip
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
from vllm.model_executor.layers.quantization.base_config import \
QuantizationConfig
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
determine_default_log2phy_map)
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
get_all_reduce_merge_state,
get_rm_router_logits_state, is_310p)
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
def __init__(self, moe: FusedMoEConfig = None):
super().__init__(moe=moe)
vllm_config = get_current_vllm_config()
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
self.max_model_len = vllm_config.model_config.max_model_len
get_ascend_config()
self.dynamic_eplb = get_ascend_config().dynamic_eplb
try:
device_group = get_mc2_group().device_group
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu"))
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
local_rank)
except AttributeError:
self.moe_all_to_all_group_name = None
def process_weights_after_loading(self, layer):
super(UnquantizedFusedMoEMethod,
self).process_weights_after_loading(layer)
layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
layer.w13_weight.data),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
layer.w2_weight.data),
requires_grad=False)
if not is_310p():
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = False,
enable_force_load_balance: bool = False,
shared_experts: Optional[Any] = None,
**kwargs,
) -> torch.Tensor:
topk_weights, topk_ids, row_idx = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
topk_weights = topk_weights.to(x.dtype)
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance and not self.use_aclgraph:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
moe_comm_method = get_forward_context().moe_comm_method
return moe_comm_method.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
global_num_experts=global_num_experts,
expert_map=expert_map,
shared_experts=shared_experts,
need_trans=True,
dynamic_eplb=self.dynamic_eplb)
class AscendFusedMoE(FusedMoE):
# The moe_counter parameter is required during the initialization of EPLB
# to identify the current layer index within the MOE model.
moe_counter = -1
def __init__(
self,
num_experts: int, # Global number of experts
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
ep_size: Optional[int] = None,
dp_size: Optional[int] = None,
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
):
# TODO: This could not initialize FusedMoE baseclass,
# fixme and make __init__() of AscendFusedMoE more clear
super().__init__(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=reduce_results,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group,
quant_config=quant_config,
tp_size=tp_size,
ep_size=ep_size,
dp_size=dp_size,
prefix=prefix,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
)
AscendFusedMoE.moe_counter += 1
self.moe_instance_id = AscendFusedMoE.moe_counter
if params_dtype is None:
params_dtype = torch.get_default_dtype()
vllm_config = get_current_vllm_config()
self.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=(tp_size if tp_size is not None else
get_tensor_model_parallel_world_size()),
dp_size_=(dp_size
if dp_size is not None else get_dp_group().world_size),
vllm_parallel_config=vllm_config.parallel_config)
self.top_k = top_k
self.num_experts = num_experts
self.global_num_experts = num_experts
assert intermediate_size % self.tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
self.expert_map = None
self.activation = activation
self.log2phy = None
self.global_redundant_expert_num = 0
is_deepseek_v3_r1 = self.global_num_experts == 256
self.rm_router_logits = get_rm_router_logits_state(
self.moe_parallel_config.ep_size, self.dp_size, is_deepseek_v3_r1)
self.all_reduce_merge = get_all_reduce_merge_state(
self.moe_parallel_config.ep_size, is_deepseek_v3_r1)
ascend_config = get_ascend_config()
self.dynamic_eplb = ascend_config.dynamic_eplb
self.expert_map_path = ascend_config.expert_map_path
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
self.global_num_experts = num_experts + self.global_redundant_expert_num
# static eplb initializing with expert_map_path
if self.expert_map_path and os.path.exists(
self.expert_map_path) and os.access(self.expert_map_path,
os.R_OK):
self.expert_load_balancer = ExpertLoadBalancer(
self.expert_map_path, self.global_num_experts)
self.local_num_experts, self.expert_map = (
self.expert_load_balancer.get_rank_placement_map(
self.moe_instance_id, self.ep_rank))
self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(
self.moe_instance_id, self.ep_rank).npu()
self.global_redundant_expert_num = (
self.expert_load_balancer.get_global_redundant_expert_num())
else:
# init moe.
self.local_num_experts, self.expert_map = determine_expert_map(
self.ep_size, self.ep_rank, self.global_num_experts)
# dynamic eplb initializing with not expert_map_path
if self.dynamic_eplb:
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
self.local_num_experts, self.expert_map = determine_default_expert_map(
self.global_num_experts, self.ep_size, self.ep_rank,
self.global_redundant_expert_num)
self.log2phy = determine_default_log2phy_map(
self.global_num_experts, self.ep_size, self.ep_rank,
self.global_redundant_expert_num)
local_num_experts = (torch.sum(self.expert_map != -1)
if self.expert_map is not None else num_experts)
if self.dynamic_eplb:
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.")
moe = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=params_dtype,
)
self.moe_config = moe
# TODO: The self.moe_config.tp_size here is not correct, fixme soon
if quant_config is None:
self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None
local_num_experts = torch.sum(self.expert_map != -1) \
if self.expert_map is not None else num_experts
self.moe_load = None
if self.dynamic_eplb:
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
moe_quant_params = {
"num_experts": local_num_experts,
"hidden_size": hidden_size,
"intermediate_size_per_partition":
self.intermediate_size_per_partition,
"params_dtype": params_dtype,
"weight_loader": self.weight_loader,
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.ep_group = get_ep_group()
# NOTE: self.tp_group is not expert_tp_group
self.tp_group = get_tp_group().device_group
self.quant_method.create_weights(layer=self, **moe_quant_params)
self.moe_config.tp_group = get_tp_group()
self.moe_config.dp_group = get_dp_group()
self.moe_config.ep_group = get_ep_group()
self.moe_config.mc2_group = get_mc2_group()
self.moe_config.num_global_redundant_experts = self.global_redundant_expert_num
setup_moe_comm_method(self.moe_config)
def update_expert_map(self, new_expert_map):
self.expert_map = new_expert_map
def get_map(self):
return self.expert_map
def get_log2phy_map(self):
return self.logical_to_physical_map
def clear_moe_load(self):
if self.moe_load is not None:
self.moe_load.zero_()
def forward(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_prefill: bool,
enable_force_load_balance: bool = False,
top_k: Optional[int] = None,
shared_experts: Optional[Any] = None,
gate=None,
replace_allreduce: bool = False):
assert self.quant_method is not None
if top_k:
real_top_k = top_k
else:
real_top_k = self.top_k
forward_context = get_forward_context()
mc2_mask = forward_context.mc2_mask
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
quantized_x_for_share, dynamic_scale_for_share = None, None
if shared_experts:
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
shared_hidden_states = shared_experts(hidden_states)
if forward_context.sp_enabled:
replace_allreduce = True
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
hidden_states=hidden_states,
router_logits=router_logits,
enable_shared_expert_dp=self.enable_shared_expert_dp,
rm_router_logits=self.rm_router_logits,
replace_allreduce=replace_allreduce,
gate=gate)
# Matrix multiply.
e_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=real_top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
is_prefill=is_prefill,
enable_force_load_balance=enable_force_load_balance,
log2phy=self.log2phy,
global_redundant_expert_num=self.global_redundant_expert_num,
shared_experts=None,
mc2_mask=mc2_mask,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
)
group_list_type = None
if shared_experts:
if isinstance(e_hidden_states,
tuple) and len(e_hidden_states) == 2:
e_hidden_states, shared_hidden_states = e_hidden_states
if isinstance(e_hidden_states, tuple) and len(e_hidden_states) == 3:
e_hidden_states, group_list_type, expert_tokens = e_hidden_states
if self.dynamic_eplb and group_list_type is not None:
self.moe_load += expert_tokens if group_list_type else \
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
final_hidden_states = forward_context.moe_comm_method.finalize(
hidden_states=e_hidden_states,
reduce_results=(not self.all_reduce_merge))
if shared_experts:
return final_hidden_states, shared_hidden_states
else:
return final_hidden_states
# ----------------------------------------- TBO-related --------------------------------------------
def _forward_ms_fused_moe_comp(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_prefill: bool,
real_top_k,
enable_force_load_balance: bool = False,
):
hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=real_top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
is_prefill=is_prefill,
enable_force_load_balance=enable_force_load_balance,
)
return hidden_states

View File

@@ -26,6 +26,8 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.utils import get_rm_router_logits_state
class FusedMoEPrepareAndFinalize(ABC): class FusedMoEPrepareAndFinalize(ABC):
""" """
@@ -41,13 +43,16 @@ class FusedMoEPrepareAndFinalize(ABC):
def __init__(self, moe_config: FusedMoEConfig): def __init__(self, moe_config: FusedMoEConfig):
self.moe_config = moe_config self.moe_config = moe_config
is_deepseek_v3_r1 = self.moe_config.original_num_experts == 256
self.rm_router_logits = get_rm_router_logits_state(
self.moe_config.ep_size, self.moe_config.dp_size,
is_deepseek_v3_r1)
@abstractmethod @abstractmethod
def prepare(self, def prepare(self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False, enable_shared_expert_dp: bool = False,
rm_router_logits: bool = False,
replace_allreduce: bool = False, replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
@@ -61,7 +66,6 @@ class FusedMoEPrepareAndFinalize(ABC):
hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size] hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size]
router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts] router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
enable_shared_expert_dp (bool): Skip DP communication for shared experts enable_shared_expert_dp (bool): Skip DP communication for shared experts
rm_router_logits (bool): Discard input router_logits and recompute via gate
replace_allreduce (bool): Bypass default all-reduce behavior replace_allreduce (bool): Bypass default all-reduce behavior
gate (nn.Module, optional): Gate network to recompute router_logits if needed gate (nn.Module, optional): Gate network to recompute router_logits if needed
@@ -116,7 +120,6 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False, enable_shared_expert_dp: bool = False,
rm_router_logits: bool = False,
replace_allreduce: bool = False, replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
@@ -215,7 +218,6 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False, enable_shared_expert_dp: bool = False,
rm_router_logits: bool = False,
replace_allreduce: bool = False, replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
@@ -294,7 +296,6 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False, enable_shared_expert_dp: bool = False,
rm_router_logits: bool = False,
replace_allreduce: bool = False, replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
@@ -302,7 +303,6 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
1. Fetch max token count across DP group from forward context. 1. Fetch max token count across DP group from forward context.
2. Pad local tensors to that size. 2. Pad local tensors to that size.
3. All-gather across DP group to form global input tensor. 3. All-gather across DP group to form global input tensor.
4. Optionally recompute router_logits using gate if `rm_router_logits=True`.
Returns: Returns:
Tuple of (global_hidden_states, global_router_logits, None) Tuple of (global_hidden_states, global_router_logits, None)
@@ -318,14 +318,14 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
if pad_size > 0: if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states, hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size)) (0, 0, 0, pad_size))
if not rm_router_logits: if not self.rm_router_logits:
router_logits = nn.functional.pad(router_logits, router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size)) (0, 0, 0, pad_size))
# All-gather across DP group # All-gather across DP group
hidden_states = self.moe_config.dp_group.all_gather( hidden_states = self.moe_config.dp_group.all_gather(
hidden_states, 0) hidden_states, 0)
if rm_router_logits: if self.rm_router_logits:
router_logits, _ = gate(hidden_states) # Recompute globally router_logits, _ = gate(hidden_states) # Recompute globally
else: else:
router_logits = self.moe_config.dp_group.all_gather( router_logits = self.moe_config.dp_group.all_gather(
@@ -399,14 +399,12 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False, enable_shared_expert_dp: bool = False,
rm_router_logits: bool = False,
replace_allreduce: bool = False, replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Preparation steps: Preparation steps:
1. Fetch cumulative token boundaries from forward context. 1. Fetch cumulative token boundaries from forward context.
2. Multicast hidden_states and router_logits to form global tensors. 2. Multicast hidden_states and router_logits to form global tensors.
3. Optionally recompute router_logits globally if `rm_router_logits=True`.
Returns: Returns:
Tuple of (global_hidden_states, global_router_logits, None) Tuple of (global_hidden_states, global_router_logits, None)
@@ -418,7 +416,7 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
).dp_metadata.cu_tokens_across_sp(1) ).dp_metadata.cu_tokens_across_sp(1)
hidden_states = self._naive_multicast(hidden_states, hidden_states = self._naive_multicast(hidden_states,
self.cu_tokens_across_dp_cpu) self.cu_tokens_across_dp_cpu)
if rm_router_logits: if self.rm_router_logits:
router_logits, _ = gate(hidden_states) router_logits, _ = gate(hidden_states)
else: else:
router_logits = self._naive_multicast( router_logits = self._naive_multicast(

View File

@@ -67,12 +67,11 @@ class MoECommMethod(ABC):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False, enable_shared_expert_dp: bool = False,
rm_router_logits: bool = False,
replace_allreduce: bool = False, replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor]: gate=None) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare( hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare(
hidden_states, router_logits, enable_shared_expert_dp, hidden_states, router_logits, enable_shared_expert_dp,
rm_router_logits, replace_allreduce, gate) replace_allreduce, gate)
self.mc2_mask = mc2_mask self.mc2_mask = mc2_mask
return hidden_states, router_logits return hidden_states, router_logits

View File

@@ -468,9 +468,6 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
super().__init__(**kwargs) super().__init__(**kwargs)
self.with_quant = False self.with_quant = False
self.num_local_experts = kwargs.get("num_local_experts", 0) self.num_local_experts = kwargs.get("num_local_experts", 0)
self.num_global_redundant_experts = kwargs.get(
"num_global_redundant_experts", 0)
self.num_experts = self.num_experts + self.num_global_redundant_experts
self.hidden_shape = None self.hidden_shape = None
self.topk_weights = None self.topk_weights = None

View File

@@ -37,7 +37,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group, from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
get_otp_group) get_otp_group)
from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable, from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
oproj_tp_enable) oproj_tp_enable)

View File

@@ -53,9 +53,9 @@ from vllm.sequence import IntermediateTensors
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.torchair.ops.sequence_parallel import (MetadataForPadding, from vllm_ascend.torchair.ops.sequence_parallel import (MetadataForPadding,
init_metadata_for_sp) init_metadata_for_sp)
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock): class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -81,7 +81,7 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
prefix=f"{prefix}.gate", prefix=f"{prefix}.gate",
) )
self.experts = AscendFusedMoE( self.experts = TorchairAscendFusedMoE(
num_experts=config.num_experts, num_experts=config.num_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,

View File

@@ -880,7 +880,7 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
# this is a naive implementation for experts load balance so as # this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank. # to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs. # currently it is only activated when doing profile runs.
if enable_force_load_balance and not self.use_aclgraph: if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
fused_moe_state = get_forward_context().fused_moe_state fused_moe_state = get_forward_context().fused_moe_state