Files
xc-llm-ascend/vllm_ascend/ops/fused_moe/experts_selector.py
Nengjun Ma 66b60c9440 [Refactor] Refactor MLA/SFA weight prefetch to be consistent with MoE weight prefetch (#6629)
### What this PR does / why we need it?
1. [Refactor] Refactor the MLA/SFA weight prefetch to be consistent with the MoE weight prefetch
2. Remove the duplicated o_proj weight prefetch in the MLA/SFA forward pass

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?

1) Performance results:

*) MLA:

| | 1st test | 2nd test | Output Token Throughput (avg) | Improvement percentage |
| --- | --- | --- | --- | --- |
| o_proj duplicate prefetch | 11.9669 token/s | 12.0287 token/s | 11.9978 token/s | |
| o_proj no duplicate prefetch | 12.5594 token/s | 12.6216 token/s | 12.5905 token/s | 4.94% |

Single-layer performance improvement: 5%~8%

*) SFA:

| | 1st test | 2nd test | Output Token Throughput (avg) | Improvement percentage |
| --- | --- | --- | --- | --- |
| o_proj duplicate prefetch | 13.0523 token/s | 13.1084 token/s | 13.08035 token/s | |
| o_proj no duplicate prefetch | 13.9844 token/s | 14.1678 token/s | 14.0761 token/s | 7.6% |

- vLLM version: v0.15.0
- vLLM main: d7e17aaacd

---------

Signed-off-by: leo-pony <nengjunma@outlook.com>
2026-02-10 14:14:37 +08:00


#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from collections.abc import Callable

import torch

from vllm_ascend.utils import get_weight_prefetch_method


def select_experts(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    use_grouped_topk: bool,
    renormalize: bool,
    topk_group: int | None = None,
    num_expert_group: int | None = None,
    custom_routing_function: Callable | None = None,
    scoring_func: str = "softmax",
    routed_scaling_factor=1.0,
    e_score_correction_bias: torch.Tensor | None = None,
    indices_type: torch.dtype | None = None,
    global_num_experts: int = -1,
):
    """
    Fused experts with select experts.

    Args:
        hidden_states: Hidden states of shape (num_tokens, hidden_size).
        router_logits: Router logits of shape (num_tokens, num_experts).
        top_k: Number of top-k experts to select.
        use_grouped_topk: Whether to group experts before selecting top-k.
        renormalize: Whether to renormalize the routing weights.
        topk_group: Number of expert groups to select from.
        num_expert_group: Number of experts in each group.
        custom_routing_function: Custom routing function.
        scoring_func: Scoring function to use.
        e_score_correction_bias: Correction bias to apply to expert scores.
        indices_type: dtype of the returned indices.
        global_num_experts: Global number of experts.

    Returns:
        topk_weights: Routing weights of shape (num_tokens, top_k).
        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
    """
    # Prefetch the gate_up (w1/w3) projection weights before expert selection
    weight_prefetch_method = get_weight_prefetch_method()
    weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")

    is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k(
        hidden_states=hidden_states,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        scoring_func=scoring_func,
        custom_routing_function=custom_routing_function,
    )
    if is_support_npu_moe_gating_top_k:
        topk_weights, topk_ids = _select_experts_with_fusion_ops(
            hidden_states=hidden_states,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            renormalize=renormalize,
            e_score_correction_bias=e_score_correction_bias,
            num_expert_group=num_expert_group,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            global_num_experts=global_num_experts,
        )
    else:
        topk_weights, topk_ids = _native_select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts,
        )
    return topk_weights, topk_ids
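
# Illustrative usage (a sketch, not part of this module): a MoE layer with 64
# experts and top_k=8 would typically call select_experts as below; the tensor
# shapes and expert counts here are assumptions for the example, not values
# taken from any model config.
#
#   topk_weights, topk_ids = select_experts(
#       hidden_states=hidden_states,    # (num_tokens, hidden_size)
#       router_logits=router_logits,    # (num_tokens, 64)
#       top_k=8,
#       use_grouped_topk=False,
#       renormalize=True,
#   )
#   # topk_weights: (num_tokens, 8) float; topk_ids: (num_tokens, 8) int32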


def check_npu_moe_gating_top_k(
    hidden_states: torch.Tensor,
    top_k: int,
    renormalize: bool,
    topk_group: int | None = None,
    num_expert_group: int | None = None,
    scoring_func: str = "softmax",
    custom_routing_function: Callable | None = None,
):
    # sigmoid + renormalize=False is not supported on the current branch
    if scoring_func == "sigmoid" and not renormalize:
        return False
    if custom_routing_function is not None:
        return False
    if scoring_func not in ("softmax", "sigmoid"):
        return False
    topk_group = topk_group if topk_group is not None else 1
    num_expert_group = num_expert_group if num_expert_group is not None else 1
    if not (
        num_expert_group > 0
        and hidden_states.shape[-1] % num_expert_group == 0
        and hidden_states.shape[-1] // num_expert_group > 2
    ):
        return False
    if topk_group < 1 or topk_group > num_expert_group:
        return False
    if top_k < 1 or top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
        return False
    if topk_group * hidden_states.shape[-1] / num_expert_group < top_k:  # noqa: SIM103
        return False
    return True
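
# For intuition (hypothetical values, not taken from any model config): with
# hidden_states of width 7168, softmax scoring and no custom routing function,
# the checks above pass for num_expert_group=8, topk_group=4, top_k=8, so
# select_experts takes the fused-kernel path; scoring_func="sigmoid" with
# renormalize=False fails the first check and falls back to the native path.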


def _native_grouped_topk(
    topk_weights: torch.Tensor,
    num_expert_group: int | None,
    topk_group: int | None,
):
    topk_group = 0 if topk_group is None else topk_group
    num_expert_group = 0 if num_expert_group is None else num_expert_group

    num_token = topk_weights.shape[0]
    grouped_weights = topk_weights.view(num_token, num_expert_group, -1).max(dim=-1).values
    topk_group_indices = torch.topk(grouped_weights.to(torch.float32), k=topk_group, dim=-1, sorted=False)[1]
    topk_group_mask = torch.zeros_like(grouped_weights)
    topk_group_mask.scatter_(1, topk_group_indices, 1)
    topk_weight_mask = (
        topk_group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, topk_weights.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )
    topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
    return topk_weights
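
# Toy example of the masking above (values chosen purely for illustration):
# with 8 scores per token split into num_expert_group=4 groups of 2 and
# topk_group=2, only the scores in the 2 groups with the highest per-group
# max survive; everything else is zeroed.
#
#   >>> w = torch.tensor([[0.9, 0.1, 0.2, 0.3, 0.8, 0.7, 0.0, 0.1]])
#   >>> _native_grouped_topk(w, num_expert_group=4, topk_group=2)
#   tensor([[0.9000, 0.1000, 0.0000, 0.0000, 0.8000, 0.7000, 0.0000, 0.0000]])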


def _renormalize_topk_weights(
    topk_weights: torch.Tensor,
    renormalize: bool,
):
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights


def _select_expert_use_group_topk(
    topk_weights: torch.Tensor,
    topk_group: int | None,
    renormalize: bool,
    top_k: int,
    num_expert_group: int | None,
    e_score_correction_bias: torch.Tensor | None,
):
    assert topk_group is not None
    assert num_expert_group is not None

    if e_score_correction_bias is not None:
        # Store the original scores before applying the correction bias: biased
        # scores are used for expert selection, original scores for routing weights.
        original_weights = topk_weights
        topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)

    # TODO: Change to npu_group_topk when the latest CANN and NNAL are available
    # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
    topk_weights = _native_grouped_topk(topk_weights, num_expert_group, topk_group)
    # TODO: bfloat16 is not supported by torch.topk in the GE graph.
    if e_score_correction_bias is not None:
        topk_ids = torch.topk(topk_weights.to(torch.float32), k=top_k, dim=-1, sorted=False)[1]
        # Use the original, unbiased scores for the routing weights
        topk_weights = original_weights.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32), k=top_k, dim=-1, sorted=False)
    topk_ids = topk_ids.to(torch.int32)
    topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
    return topk_weights, topk_ids
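
# Worked sketch of the bias-corrected selection above, ignoring the grouping
# step for brevity (illustrative numbers): with scores [0.2, 0.5] and
# e_score_correction_bias [0.4, 0.0], the biased scores [0.6, 0.5] decide
# *which* expert wins top-1 (expert 0), but the returned routing weight is the
# unbiased 0.2, gathered from the original scores. This mirrors DeepSeek-style
# aux-loss-free load balancing, where the bias steers selection without
# distorting the combine weights.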


def _select_experts_with_fusion_ops(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    use_grouped_topk: bool,
    renormalize: bool,
    e_score_correction_bias: torch.Tensor | None,
    topk_group: int | None,
    num_expert_group: int | None,
    scoring_func: str = "softmax",
    routed_scaling_factor=1.0,
    global_num_experts: int = -1,
):
    topk_group = topk_group if topk_group is not None else 1
    num_expert_group = num_expert_group if num_expert_group is not None else 1
    renorm = int(renormalize)
    norm_type = 0 if scoring_func == "softmax" else 1
    if e_score_correction_bias is not None and e_score_correction_bias.dtype != router_logits.dtype:
        e_score_correction_bias = e_score_correction_bias.to(router_logits.dtype)
    topk_weights, topk_ids, _ = torch.ops._C_ascend.moe_gating_top_k(
        router_logits,
        k=top_k,
        k_group=topk_group,
        group_count=num_expert_group,
        group_select_mode=1,
        renorm=renorm,
        norm_type=norm_type,  # 0: softmax; 1: sigmoid
        out_flag=False,
        routed_scaling_factor=routed_scaling_factor,
        eps=1e-20,
        bias_opt=e_score_correction_bias,
    )
    return topk_weights, topk_ids
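
# Note: torch.ops._C_ascend.moe_gating_top_k fuses scoring (norm_type),
# grouped top-k (k_group/group_count) and renormalization (renorm) into a
# single NPU kernel, replacing the softmax/sigmoid + _native_grouped_topk +
# torch.topk sequence of the native path below. A rough correspondence, as a
# sketch rather than a guaranteed numerical equivalence:
#
#   norm_type=0          ~ router_logits.softmax(dim=-1)
#   norm_type=1          ~ router_logits.sigmoid()
#   k_group/group_count  ~ _native_grouped_topk(..., num_expert_group, topk_group)
#   renorm=1             ~ _renormalize_topk_weights(..., renormalize=True)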


def _native_select_experts(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    use_grouped_topk: bool,
    renormalize: bool,
    topk_group: int | None = None,
    num_expert_group: int | None = None,
    custom_routing_function: Callable | None = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: torch.Tensor | None = None,
    global_num_experts: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Select top-k experts based on router logits.

    Args:
        hidden_states: Hidden states of shape (num_tokens, hidden_size).
        router_logits: Router logits of shape (num_tokens, num_experts).
        top_k: Number of experts to select.
        use_grouped_topk: Whether to group experts before selecting top-k.
        renormalize: Whether to renormalize the routing weights.
        topk_group: Number of expert groups to select from.
        num_expert_group: Number of experts in each group.
        custom_routing_function: Custom routing function.
        scoring_func: Scoring function to use.
        e_score_correction_bias: Correction bias to apply to expert scores.

    Returns:
        topk_weights: Routing weights of shape (num_tokens, top_k).
        topk_ids: Selected expert IDs of shape (num_tokens, top_k).

    Raises:
        ValueError: If an unsupported scoring function is provided.
    """
    if scoring_func == "softmax":
        topk_weights = router_logits.softmax(dim=-1)
    elif scoring_func == "sigmoid":
        topk_weights = router_logits.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    if use_grouped_topk:
        return _select_expert_use_group_topk(
            topk_weights=topk_weights,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            e_score_correction_bias=e_score_correction_bias,
        )
    if custom_routing_function is not None:
        topk_weights, topk_ids = custom_routing_function(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
            global_num_experts=global_num_experts,
        )
        # Required by npu_moe_init_routing
        topk_ids = topk_ids.to(torch.int32)
        return topk_weights, topk_ids

    topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
    topk_weights = topk_weights.to(hidden_states.dtype)
    # Required by npu_moe_init_routing
    topk_ids = topk_ids.to(torch.int32)
    topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
    return topk_weights, topk_ids


def zero_experts_compute(
    expert_indices: torch.Tensor,
    expert_scales: torch.Tensor,
    num_experts: int,
    zero_expert_type: str,
    hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if zero_expert_type == "identity":
        # Zero experts are indices at or beyond num_experts; they pass the
        # hidden states through, scaled by their routing weights.
        zero_expert_mask = expert_indices < num_experts
        zero_expert_scales = expert_scales.clone()
        zero_expert_scales = torch.where(zero_expert_mask, 0.0, zero_expert_scales)

    hidden_states = hidden_states.unsqueeze(1)
    zero_expert_scales = zero_expert_scales.unsqueeze(2)
    result = hidden_states * zero_expert_scales
    result = result.sum(dim=1)

    normal_expert_mask = expert_indices >= num_experts
    expert_indices = torch.where(normal_expert_mask, 0, expert_indices)
    expert_scales = torch.where(normal_expert_mask, 0.0, expert_scales)
    return expert_indices, expert_scales, result
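
# Illustrative example (hypothetical shapes and values): with num_experts=64,
# a routed index of 64 or above denotes a "zero expert". For expert_indices
# [[3, 64]] and expert_scales [[0.7, 0.3]], the identity branch yields
# result = 0.3 * hidden_states (the zero expert's share), while the normal
# expert keeps index 3 with scale 0.7 and the zero expert's slot is masked to
# index 0 with scale 0.0 for the regular MoE computation.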