refactor select_experts of moe module (#2150)

### What this PR does / why we need it?
this pr refactor select_experts of moe module
i merge implementations of quantitative and non-quantitative method in a
new class
use such as vllm like ExpertsSelector.select_experts
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
test in qwen3-moe and all ut.

- vLLM version: v0.10.0
- vLLM main:
e18859298d

Signed-off-by: yangcheng <yangcheng104@huawei.com>
Co-authored-by: yangcheng (AJ) <y00806874@china.huawei.com>
This commit is contained in:
shiyuan680
2025-08-14 11:50:53 +08:00
committed by GitHub
parent 103654ccd6
commit e14f2ef669
10 changed files with 359 additions and 370 deletions

View File

@@ -26,7 +26,8 @@ import pytest
import torch
from vllm.model_executor.layers.activation import SiluAndMul
from vllm_ascend.ops.fused_moe import fused_experts, select_experts
from vllm_ascend.ops.fused_moe import fused_experts
from vllm_ascend.ops.layers.experts_selector import select_experts
NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
@@ -142,7 +143,7 @@ def test_select_experts(
dtype=torch.int32)
custom_routing_function.return_value = (mock_weights, mock_ids)
with patch("vllm_ascend.ops.fused_moe.native_grouped_topk"
with patch("vllm_ascend.ops.layers.experts_selector._native_grouped_topk"
) as mock_native_grouped_topk:
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
x)

View File

@@ -25,6 +25,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
from vllm_ascend.ascend_forward_context import _get_fused_moe_state
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
AscendUnquantizedFusedMoEMethod)
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402
adapt_patch(True)
@@ -389,3 +390,28 @@ class TestAscendUnquantizedFusedMoEMethod:
assert result.shape == (16, 2)
else:
assert result.shape == x.shape
class TestExpertsSelector:
@pytest.mark.parametrize("global_num_experts", [[256], [128]])
def test_select_experts(self, mock_dist_env, mock_moe_env,
global_num_experts):
x = torch.randn(8, 2)
router_logits = torch.randn(8, 2)
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=2,
use_grouped_topk=False,
renormalize=True,
topk_group=None,
num_expert_group=None,
custom_routing_function=None,
scoring_func="softmax",
e_score_correction_bias=None,
global_num_experts=global_num_experts)
assert topk_weights.shape == (8, 2)
assert topk_ids.shape == (8, 2)

View File

@@ -5,12 +5,13 @@ import torch
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.layers.experts_selector import (_native_grouped_topk,
select_experts)
from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
AscendW8A8FusedMoEMethod,
AscendW8A8LinearMethod,
fused_experts, fused_experts_310p,
native_grouped_topk,
quant_per_tensor, select_experts)
quant_per_tensor)
class TestQuantPerTensor(TestBase):
@@ -772,7 +773,7 @@ class TestSelectExperts(TestBase):
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.dtype, torch.int32)
@patch('vllm_ascend.quantization.w8a8.native_grouped_topk')
@patch('vllm_ascend.ops.layers.experts_selector._native_grouped_topk')
def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
"""Test grouped topk with expert score correction bias"""
mock_grouped_topk.return_value = torch.ones(self.num_tokens,
@@ -868,9 +869,9 @@ class TestNativeGroupedTopkPartialMock(TestBase):
with patch('torch.topk',
return_value=(None, expected_topk_indices)) as mock_topk:
result = native_grouped_topk(topk_weights=topk_weights,
num_expert_group=2,
topk_group=2)
result = _native_grouped_topk(topk_weights=topk_weights,
num_expert_group=2,
topk_group=2)
mock_topk.assert_called_once()
@@ -885,9 +886,9 @@ class TestNativeGroupedTopkPartialMock(TestBase):
expected_topk_indices = torch.tensor([[0], [1]])
with patch('torch.topk', return_value=(None, expected_topk_indices)):
result = native_grouped_topk(topk_weights=topk_weights,
num_expert_group=2,
topk_group=1)
result = _native_grouped_topk(topk_weights=topk_weights,
num_expert_group=2,
topk_group=1)
expected_result = torch.tensor(
[[0.1, 0.9, 0.2, 0.8, 0.0, 0.0, 0.0, 0.0],
@@ -900,7 +901,7 @@ class TestNativeGroupedTopkPartialMock(TestBase):
expected_topk_indices = torch.tensor([[0], [0]])
with patch('torch.topk', return_value=(None, expected_topk_indices)):
result = native_grouped_topk(topk_weights=topk_weights,
num_expert_group=1,
topk_group=1)
result = _native_grouped_topk(topk_weights=topk_weights,
num_expert_group=1,
topk_group=1)
self.assertTrue(result.numel() > 0)

View File

@@ -24,8 +24,8 @@ from vllm.model_executor.layers.fused_moe.layer import \
UnquantizedFusedMoEMethod
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ops.fused_moe import (fused_experts_moge, select_experts,
unified_fused_experts)
from vllm_ascend.ops.fused_moe import fused_experts_moge, unified_fused_experts
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.utils import is_310p
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
@@ -59,7 +59,7 @@ def forward_oot(
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
global_num_experts: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
@@ -69,7 +69,6 @@ def forward_oot(
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
global_num_experts=global_num_experts,
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
@@ -80,7 +79,7 @@ def forward_oot(
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
global_num_experts=global_num_experts)
if topk_ids.shape[1] < top_k or is_310p():
assert global_num_experts is not None

View File

@@ -46,6 +46,7 @@ from vllm_ascend.distributed.communication_op import \
from vllm_ascend.distributed.moe_comm_method import MoECommMethod
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
@@ -920,143 +921,6 @@ def fused_experts(
return final_hidden_states
def native_grouped_topk(
topk_weights: torch.Tensor,
num_expert_group: Optional[int],
topk_group: Optional[int],
):
topk_group = 0 if topk_group is None else topk_group
num_expert_group = 0 if num_expert_group is None else num_expert_group
num_token = topk_weights.shape[0]
grouped_weights = topk_weights.view(num_token, num_expert_group,
-1).max(dim=-1).values
topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
k=topk_group,
dim=-1,
sorted=False)[1]
topk_group_mask = torch.zeros_like(grouped_weights)
topk_group_mask.scatter_(1, topk_group_indices, 1)
topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
return topk_weights
def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
global_num_experts: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Select top-k experts based on router logits.
Args:
hidden_states: Hidden states of shape (num_tokens, hidden_size).
router_logits: Router logits of shape (num_tokens, num_experts).
top_k: Number of experts to select.
use_grouped_topk: Whether to group experts before selecting top-k.
renormalize: Whether to renormalize the routing weights.
topk_group: Number of expert groups to select from.
num_expert_group: Number of experts in each group.
custom_routing_function: Custom routing function.
scoring_func: Scoring function to use.
e_score_correction_bias: Correction bias to apply to expert scores.
Returns:
topk_weights: Routing weights of shape (num_tokens, top_k).
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
Raises:
ValueError: If an unsupported scoring function is provided.
"""
def _renormalize_topk_weights(
topk_weights: torch.Tensor,
renormalize: bool,
):
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1,
keepdim=True)
return topk_weights
if scoring_func == "softmax":
# NOTE: vLLM use dtype=torch.float here
if not use_grouped_topk and custom_routing_function is None:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
x=router_logits, finished=None, k=top_k)
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids
topk_weights = router_logits.softmax(dim=-1)
elif scoring_func == "sigmoid":
topk_weights = router_logits.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_weights = topk_weights
topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
# TODO: Change to npu_group_topk when the latest CANN and NNAL is available
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
topk_weights = native_grouped_topk(topk_weights, num_expert_group,
topk_group)
# TODO bfloat16 is not supported in torch.topk with ge graph.
if e_score_correction_bias is not None:
topk_ids = torch.topk(topk_weights.to(torch.float32),
k=top_k,
dim=-1,
sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_weights.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
k=top_k,
dim=-1,
sorted=False)
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids
if custom_routing_function is not None:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
global_num_experts=global_num_experts)
# Required by npu_moe_init_routing
topk_ids = topk_ids.to(torch.int32)
return topk_weights, topk_ids
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
topk_weights = topk_weights.to(hidden_states.dtype)
# Required by npu_moe_init_routing
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
def __init__(self, moe: FusedMoEConfig = None):
@@ -1111,36 +975,19 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
**kwargs,
) -> torch.Tensor:
is_deepseek_v3_r1 = global_num_experts == 256
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if is_deepseek_v3_r1:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk currently is 8
bias=e_score_correction_bias,
k_group=topk_group, # fix: 4
group_count=num_expert_group, # fix 8
group_select_mode=
1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=1, # 0: softmax; 1: sigmoid(fix)
# out_flag=False, # todo new api; should the third output be output
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
else:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts,
is_unquantized=True)
topk_weights = topk_weights.to(x.dtype)
# this is a naive implementation for experts load balance so as

View File

View File

@@ -0,0 +1,269 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Callable, Optional
import torch
import torch_npu
def select_experts(hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
indices_type: Optional[torch.dtype] = None,
is_unquantized: bool = False,
global_num_experts: int = -1):
"""
Fused experts with select experts.
Args:
router_logits: router logits of shape (num_tokens, hidden_size).
hidden_states: Hidden states of shape (num_tokens, hidden_size).
top_k: number of top k experts.
use_grouped_topk: Whether to group experts before selecting top-k.
renormalize: Whether to renormalize the routing weights.
topk_group: Number of expert groups to select from.
num_expert_group: Number of experts in each group.
custom_routing_function: Custom routing function.
scoring_func: Scoring function to use.
e_score_correction_bias: Correction bias to apply to expert scores.
indices_type: dtype of indices
is_unquantized: Whether the data are unquantized.
global_num_experts: Global number of experts.
Returns:
topk_weights: router weights of shape (num_tokens, top_k).
topk_ids: selected expert IDs of shape (num_tokens, top_k).
"""
topk_weights, topk_ids = _select_experts_with_fusion_ops(
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
renormalize=renormalize,
e_score_correction_bias=e_score_correction_bias,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
global_num_experts=global_num_experts,
is_unquantized=is_unquantized)
if topk_weights is None:
topk_weights, topk_ids = _native_select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts,
)
return topk_weights, topk_ids
def _native_grouped_topk(
topk_weights: torch.Tensor,
num_expert_group: Optional[int],
topk_group: Optional[int],
):
topk_group = 0 if topk_group is None else topk_group
num_expert_group = 0 if num_expert_group is None else num_expert_group
num_token = topk_weights.shape[0]
grouped_weights = topk_weights.view(num_token, num_expert_group,
-1).max(dim=-1).values
topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
k=topk_group,
dim=-1,
sorted=False)[1]
topk_group_mask = torch.zeros_like(grouped_weights)
topk_group_mask.scatter_(1, topk_group_indices, 1)
topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
return topk_weights
def _renormalize_topk_weights(
topk_weights: torch.Tensor,
renormalize: bool,
):
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights
def _select_expert_use_group_topk(
topk_weights: torch.Tensor, topk_group: Optional[int],
renormalize: bool, top_k: int, num_expert_group: Optional[int],
e_score_correction_bias: Optional[torch.Tensor]):
assert topk_group is not None
assert num_expert_group is not None
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_weights = topk_weights
topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
# TODO: Change to npu_group_topk when the latest CANN and NNAL is available
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
topk_weights = _native_grouped_topk(topk_weights, num_expert_group,
topk_group)
# TODO bfloat16 is not supported in torch.topk with ge graph.
if e_score_correction_bias is not None:
topk_ids = torch.topk(topk_weights.to(torch.float32),
k=top_k,
dim=-1,
sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_weights.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
k=top_k,
dim=-1,
sorted=False)
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids
def _select_experts_with_fusion_ops(
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
e_score_correction_bias: Optional[torch.Tensor],
topk_group: Optional[int],
num_expert_group: Optional[int],
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
global_num_experts: int = -1,
is_unquantized: bool = False):
topk_weights, topk_ids = None, None
# NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
is_deepseek_v3_r1 = global_num_experts == 256
if is_deepseek_v3_r1:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk currently 8
bias=e_score_correction_bias,
k_group=topk_group, # fix: 4
group_count=num_expert_group, # fix 8
group_select_mode=
1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=1, # 0: softmax; 1: sigmoid(fix)
# out_flag=False, # todo new api; should the third output be output
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax" and is_unquantized:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
x=router_logits, finished=None, k=top_k)
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids
def _native_select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
global_num_experts: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Select top-k experts based on router logits.
Args:
hidden_states: Hidden states of shape (num_tokens, hidden_size).
router_logits: Router logits of shape (num_tokens, num_experts).
top_k: Number of experts to select.
use_grouped_topk: Whether to group experts before selecting top-k.
renormalize: Whether to renormalize the routing weights.
topk_group: Number of expert groups to select from.
num_expert_group: Number of experts in each group.
custom_routing_function: Custom routing function.
scoring_func: Scoring function to use.
e_score_correction_bias: Correction bias to apply to expert scores.
Returns:
topk_weights: Routing weights of shape (num_tokens, top_k).
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
Raises:
ValueError: If an unsupported scoring function is provided.
"""
if scoring_func == "softmax":
topk_weights = router_logits.softmax(dim=-1)
elif scoring_func == "sigmoid":
topk_weights = router_logits.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
if use_grouped_topk:
return _select_expert_use_group_topk(
topk_weights=topk_weights,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
e_score_correction_bias=e_score_correction_bias)
if custom_routing_function is not None:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
global_num_experts=global_num_experts)
# Required by npu_moe_init_routing
topk_ids = topk_ids.to(torch.int32)
return topk_weights, topk_ids
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
topk_weights = topk_weights.to(hidden_states.dtype)
# Required by npu_moe_init_routing
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids

View File

@@ -27,7 +27,7 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe import select_experts
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.quantization.w8a8_dynamic import (fused_experts_with_all2all,
fused_experts_with_mc2)
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
@@ -245,34 +245,18 @@ class AscendW4A8DynamicFusedMoEMethod:
1] == global_num_experts, "Number of global experts mismatch"
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if global_num_experts == 256:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk currently is 8
bias=e_score_correction_bias,
k_group=topk_group, # fix: 4
group_count=num_expert_group, # fix 8
group_select_mode=
1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=1, # 0: softmax; 1: sigmoid(fix)
# out_flag=False, # todo new api; should the third output be output
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
else:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
fused_moe_state = get_forward_context().fused_moe_state
shared_gate_up, shared_dequant_scale = None, None
@@ -314,7 +298,7 @@ class AscendW4A8DynamicFusedMoEMethod:
mc2_mask=kwargs.get("mc2_mask", None))
else:
# The current implementation of deepseek moe splits hidden_states
# according to tp_size before they are feed into fused_moe module.
# according to tp_size before they are feed into layers module.
# Therefore, all2all is needed no matter how dp/tp is set so as to
# dispatch/combine tokens.
return fused_experts_with_all2all(

View File

@@ -23,6 +23,7 @@ from vllm.attention.backends.abstract import AttentionType
from vllm.distributed.parallel_state import get_ep_group
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
@@ -251,8 +252,7 @@ class AscendW8A8FusedMoEMethod:
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts,
)
global_num_experts=global_num_experts)
if is_310p():
return fused_experts_310p(hidden_states=x,
@@ -645,123 +645,3 @@ def fused_experts(
"currently does not support tensor parallelism")
return final_hidden_states
def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
global_num_experts=-1,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Select top-k experts based on router logits.
Args:
hidden_states: Hidden states of shape (num_tokens, hidden_size).
router_logits: Router logits of shape (num_tokens, num_experts).
top_k: Number of experts to select.
use_grouped_topk: Whether to group experts before selecting top-k.
renormalize: Whether to renormalize the routing weights.
topk_group: Number of expert groups to select from.
num_expert_group: Number of experts in each group.
custom_routing_function: Custom routing function.
scoring_func: Scoring function to use.
e_score_correction_bias: Correction bias to apply to expert scores.
Returns:
topk_weights: Routing weights of shape (num_tokens, top_k).
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
Raises:
ValueError: If an unsupported scoring function is provided.
"""
if scoring_func == "softmax":
# NOTE: vLLM use dtype=torch.float here
topk_weights = router_logits.softmax(dim=-1)
elif scoring_func == "sigmoid":
topk_weights = router_logits.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_weights = topk_weights
topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
# TODO: Change to npu_group_topk when the latest CANN and NNAL is available
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
topk_weights = native_grouped_topk(topk_weights, num_expert_group,
topk_group)
# TODO bfloat16 is not supported in torch.topk with ge graph.
if e_score_correction_bias is not None:
topk_ids = torch.topk(topk_weights.to(torch.float32),
k=top_k,
dim=-1,
sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_weights.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
k=top_k,
dim=-1,
sorted=False)
elif custom_routing_function is None:
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
topk_weights = topk_weights.to(hidden_states.dtype)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
global_num_experts=global_num_experts,
)
# Required by npu_moe_init_routing
topk_ids = topk_ids.to(torch.int32)
return topk_weights, topk_ids
# Required by npu_moe_init_routing
topk_ids = topk_ids.to(torch.int32)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
def native_grouped_topk(
topk_weights: torch.Tensor,
num_expert_group: Optional[int],
topk_group: Optional[int],
):
topk_group = 0 if topk_group is None else topk_group
num_expert_group = 0 if num_expert_group is None else num_expert_group
num_token = topk_weights.shape[0]
grouped_weights = topk_weights.view(num_token, num_expert_group,
-1).max(dim=-1).values
topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
k=topk_group,
dim=-1,
sorted=False)[1]
topk_group_mask = torch.zeros_like(grouped_weights)
topk_group_mask.scatter_(1, topk_group_indices, 1)
topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
return topk_weights

View File

@@ -27,7 +27,7 @@ import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe import select_experts
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
dispose_tensor, get_ascend_soc_version)
@@ -903,36 +903,18 @@ class AscendW8A8DynamicFusedMoEMethod:
assert router_logits.shape[
1] == global_num_experts, "Number of global experts mismatch"
is_deepseek_v3_r1 = global_num_experts == 256
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if is_deepseek_v3_r1:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk当前写8
bias=e_score_correction_bias,
k_group=topk_group, # fix: 4
group_count=num_expert_group, # fix 8
group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=1, # 0: softmax; 1: sigmoid(fix)
# out_flag=False, # todo new api; 第三个输出是否输出
# y2_flag=False, # old api; 第三个输出是否输出
routed_scaling_factor=1,
eps=float(1e-20))
else:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
fused_moe_state = get_forward_context().fused_moe_state
shared_gate_up, shared_dequant_scale = None, None
@@ -995,7 +977,7 @@ class AscendW8A8DynamicFusedMoEMethod:
expert_map=expert_map)
else:
# The current implementation of deepseek moe splits hidden_states
# according to tp_size before they are feed into fused_moe module.
# according to tp_size before they are feed into layers module.
# Therefore, all2all is needed no matter how dp/tp is set so as to
# dispatch/combine tokens.
return fused_experts_with_all2all(