#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections.abc import Callable

import torch

from vllm_ascend.utils import get_weight_prefetch_method


def select_experts(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    use_grouped_topk: bool,
    renormalize: bool,
    topk_group: int | None = None,
    num_expert_group: int | None = None,
    custom_routing_function: Callable | None = None,
    scoring_func: str = "softmax",
    routed_scaling_factor=1.0,
    e_score_correction_bias: torch.Tensor | None = None,
    indices_type: torch.dtype | None = None,
    global_num_experts: int = -1,
):
    """
    Fused experts with select experts.

    Dispatches to the NPU fused gating op when the configuration is supported
    (see ``check_npu_moe_gating_top_k``), otherwise falls back to the native
    PyTorch implementation.

    Args:
        hidden_states: Hidden states of shape (num_tokens, hidden_size).
        router_logits: router logits of shape (num_tokens, hidden_size).
        top_k: number of top k experts.
        use_grouped_topk: Whether to group experts before selecting top-k.
        renormalize: Whether to renormalize the routing weights.
        topk_group: Number of expert groups to select from.
        num_expert_group: Number of experts in each group.
        custom_routing_function: Custom routing function.
        scoring_func: Scoring function to use ("softmax" or "sigmoid").
        routed_scaling_factor: Scaling factor forwarded to the fused gating op.
        e_score_correction_bias: Correction bias to apply to expert scores.
        indices_type: dtype of indices.
        global_num_experts: Global number of experts.

    Returns:
        topk_weights: router weights of shape (num_tokens, top_k).
        topk_ids: selected expert IDs of shape (num_tokens, top_k).
    """
    # prefetch w1_w3_proj.weight preprocess
    weight_prefetch_method = get_weight_prefetch_method()
    weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")

    is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k(
        hidden_states=hidden_states,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        scoring_func=scoring_func,
        custom_routing_function=custom_routing_function,
    )

    if is_support_npu_moe_gating_top_k:
        topk_weights, topk_ids = _select_experts_with_fusion_ops(
            hidden_states=hidden_states,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            renormalize=renormalize,
            e_score_correction_bias=e_score_correction_bias,
            num_expert_group=num_expert_group,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            global_num_experts=global_num_experts,
        )
    else:
        topk_weights, topk_ids = _native_select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts,
        )
    return topk_weights, topk_ids


def check_npu_moe_gating_top_k(
    hidden_states: torch.Tensor,
    top_k: int,
    renormalize: bool,
    topk_group: int | None = None,
    num_expert_group: int | None = None,
    scoring_func: str = "softmax",
    custom_routing_function: Callable | None = None,
) -> bool:
    """Return True when the fused ``moe_gating_top_k`` NPU op supports this
    routing configuration, False to fall back to the native path.

    NOTE(review): the divisibility/range checks below read
    ``hidden_states.shape[-1]``, which is the hidden size — expert-group
    arithmetic usually applies to the expert count (router_logits' last dim).
    Looks intentional in this branch but worth confirming against callers.
    """
    if scoring_func == "sigmoid" and not renormalize:
        # sigmoid + renorm=0 is not supported in current branch
        return False
    if custom_routing_function is not None:
        return False
    if scoring_func not in ("softmax", "sigmoid"):
        return False

    # The fused op treats "no grouping" as a single group.
    topk_group = topk_group if topk_group is not None else 1
    num_expert_group = num_expert_group if num_expert_group is not None else 1

    # Each group must be non-empty and contain more than 2 candidates.
    if not (
        num_expert_group > 0
        and hidden_states.shape[-1] % num_expert_group == 0
        and hidden_states.shape[-1] // num_expert_group > 2
    ):
        return False
    if topk_group < 1 or topk_group > num_expert_group:
        return False
    if top_k < 1 or top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
        return False
    if topk_group * hidden_states.shape[-1] / num_expert_group < top_k:  # noqa: SIM103
        return False
    return True


def _native_grouped_topk(
    topk_weights: torch.Tensor,
    num_expert_group: int | None,
    topk_group: int | None,
) -> torch.Tensor:
    """Zero out the scores of all experts outside the top ``topk_group``
    expert groups; group score = max expert score within the group.

    Args:
        topk_weights: Expert scores of shape (num_tokens, num_experts).
        num_expert_group: Number of equally-sized expert groups.
        topk_group: Number of groups to keep.

    Returns:
        Scores with non-selected groups masked to 0.0, same shape as input.
    """
    topk_group = 0 if topk_group is None else topk_group
    num_expert_group = 0 if num_expert_group is None else num_expert_group

    num_token = topk_weights.shape[0]
    # Score each group by its best expert.
    grouped_weights = topk_weights.view(num_token, num_expert_group, -1).max(dim=-1).values
    # float32: torch.topk on bfloat16 is not supported in the ge graph.
    topk_group_indices = torch.topk(grouped_weights.to(torch.float32), k=topk_group, dim=-1, sorted=False)[1]
    topk_group_mask = torch.zeros_like(grouped_weights)
    topk_group_mask.scatter_(1, topk_group_indices, 1)
    # Broadcast the per-group mask back to per-expert granularity.
    topk_weight_mask = (
        topk_group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, topk_weights.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )
    topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
    return topk_weights


def _renormalize_topk_weights(
    topk_weights: torch.Tensor,
    renormalize: bool,
) -> torch.Tensor:
    """Optionally rescale routing weights so each token's weights sum to 1."""
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights


def _select_expert_use_group_topk(
    topk_weights: torch.Tensor,
    topk_group: int | None,
    renormalize: bool,
    top_k: int,
    num_expert_group: int | None,
    e_score_correction_bias: torch.Tensor | None,
):
    """Grouped top-k expert selection (native fallback path).

    Selects experts from the best ``topk_group`` expert groups. When
    ``e_score_correction_bias`` is given, biased scores drive the selection
    but the original (unbiased) scores become the routing weights.
    """
    assert topk_group is not None
    assert num_expert_group is not None

    if e_score_correction_bias is not None:
        # Store original scores before applying correction bias. We use biased
        # scores for expert selection but original scores for routing weights
        original_weights = topk_weights
        topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)

    # TODO: Change to npu_group_topk when the latest CANN and NNAL is available
    # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
    topk_weights = _native_grouped_topk(topk_weights, num_expert_group, topk_group)

    # TODO bfloat16 is not supported in torch.topk with ge graph.
    if e_score_correction_bias is not None:
        topk_ids = torch.topk(topk_weights.to(torch.float32), k=top_k, dim=-1, sorted=False)[1]
        # Use original unbiased scores for the routing weights
        topk_weights = original_weights.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32), k=top_k, dim=-1, sorted=False)

    # Required by downstream NPU routing kernels.
    topk_ids = topk_ids.to(torch.int32)
    topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
    return topk_weights, topk_ids


def _select_experts_with_fusion_ops(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    use_grouped_topk: bool,
    renormalize: bool,
    e_score_correction_bias: torch.Tensor | None,
    topk_group: int | None,
    num_expert_group: int | None,
    scoring_func: str = "softmax",
    routed_scaling_factor=1.0,
    global_num_experts: int = -1,
):
    """Select experts via the fused NPU ``moe_gating_top_k`` op.

    Only valid for configurations accepted by ``check_npu_moe_gating_top_k``.
    ``hidden_states``, ``use_grouped_topk`` and ``global_num_experts`` are
    accepted for signature parity with the native path but are not read here.
    """
    # The fused op expects concrete group parameters; absent grouping == 1.
    topk_group = topk_group if topk_group is not None else 1
    num_expert_group = num_expert_group if num_expert_group is not None else 1
    renorm = int(renormalize)
    norm_type = 0 if scoring_func == "softmax" else 1

    # The op requires the bias dtype to match the logits dtype.
    if e_score_correction_bias is not None and e_score_correction_bias.dtype != router_logits.dtype:
        e_score_correction_bias = e_score_correction_bias.to(router_logits.dtype)

    topk_weights, topk_ids, _ = torch.ops._C_ascend.moe_gating_top_k(
        router_logits,
        k=top_k,
        k_group=topk_group,
        group_count=num_expert_group,
        group_select_mode=1,
        renorm=renorm,
        norm_type=norm_type,  # 0: softmax; 1: sigmoid
        out_flag=False,
        routed_scaling_factor=routed_scaling_factor,
        eps=1e-20,
        bias_opt=e_score_correction_bias,
    )
    return topk_weights, topk_ids


def _native_select_experts(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    use_grouped_topk: bool,
    renormalize: bool,
    topk_group: int | None = None,
    num_expert_group: int | None = None,
    custom_routing_function: Callable | None = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: torch.Tensor | None = None,
    global_num_experts: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Select top-k experts based on router logits.

    Args:
        hidden_states: Hidden states of shape (num_tokens, hidden_size).
        router_logits: Router logits of shape (num_tokens, num_experts).
        top_k: Number of experts to select.
        use_grouped_topk: Whether to group experts before selecting top-k.
        renormalize: Whether to renormalize the routing weights.
        topk_group: Number of expert groups to select from.
        num_expert_group: Number of experts in each group.
        custom_routing_function: Custom routing function.
        scoring_func: Scoring function to use.
        e_score_correction_bias: Correction bias to apply to expert scores.
        global_num_experts: Global number of experts (forwarded to the custom
            routing function when one is supplied).

    Returns:
        topk_weights: Routing weights of shape (num_tokens, top_k).
        topk_ids: Selected expert IDs of shape (num_tokens, top_k).

    Raises:
        ValueError: If an unsupported scoring function is provided.
    """
    if scoring_func == "softmax":
        topk_weights = router_logits.softmax(dim=-1)
    elif scoring_func == "sigmoid":
        topk_weights = router_logits.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    if use_grouped_topk:
        return _select_expert_use_group_topk(
            topk_weights=topk_weights,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            e_score_correction_bias=e_score_correction_bias,
        )

    if custom_routing_function is not None:
        topk_weights, topk_ids = custom_routing_function(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
            global_num_experts=global_num_experts,
        )
        # Required by npu_moe_init_routing
        topk_ids = topk_ids.to(torch.int32)
        return topk_weights, topk_ids

    topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
    topk_weights = topk_weights.to(hidden_states.dtype)
    # Required by npu_moe_init_routing
    topk_ids = topk_ids.to(torch.int32)
    topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
    return topk_weights, topk_ids


def zero_experts_compute(
    expert_indices: torch.Tensor,
    expert_scales: torch.Tensor,
    num_experts: int,
    zero_expert_type: str,
    hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Handle "zero experts" — expert IDs >= ``num_experts``.

    For ``zero_expert_type == "identity"``, a zero expert contributes the
    token's own hidden state scaled by its routing weight; real-expert slots
    contribute nothing to that identity result. Slots routed to zero experts
    are then neutralized (index 0, scale 0.0) so the normal expert compute
    ignores them.

    Returns:
        expert_indices: Indices with zero-expert slots remapped to 0.
        expert_scales: Scales with zero-expert slots zeroed.
        result: Identity contribution of shape (num_tokens, hidden_size);
            undefined (unbound) for other ``zero_expert_type`` values, which
            preserves the original behavior of raising on unknown types.
    """
    if zero_expert_type == "identity":
        # True for slots routed to a REAL expert.
        zero_expert_mask = expert_indices < num_experts
        zero_expert_scales = expert_scales.clone()
        # Keep scales only on zero-expert slots.
        zero_expert_scales = torch.where(zero_expert_mask, 0.0, zero_expert_scales)

        # (num_tokens, 1, hidden) * (num_tokens, top_k, 1) -> sum over top_k.
        hidden_states = hidden_states.unsqueeze(1)
        zero_expert_scales = zero_expert_scales.unsqueeze(2)
        result = hidden_states * zero_expert_scales
        result = result.sum(dim=1)

    # True for slots routed to a ZERO expert; neutralize them for the
    # regular expert computation.
    normal_expert_mask = expert_indices >= num_experts
    expert_indices = torch.where(normal_expert_mask, 0, expert_indices)
    expert_scales = torch.where(normal_expert_mask, 0.0, expert_scales)

    return expert_indices, expert_scales, result