# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/kernels/test_moe.py

from typing import Callable, Optional

import torch
import torch_npu
from vllm.model_executor.layers.fused_moe.layer import \
    UnquantizedFusedMoEMethod


def fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    expert_map: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Fused experts with top-k routing.

    Args:
        hidden_states: Hidden states of shape (num_tokens, hidden_size).
        w1: Expert weights1 of shape
            (num_experts, intermediate_size * 2, hidden_size).
        w2: Expert weights2 of shape
            (num_experts, hidden_size, intermediate_size).
        topk_weights: Routing weights of shape (num_tokens, top_k).
        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
        top_k: Number of experts to select.
        expert_map: Expert mapping of shape (num_experts,).

    Returns:
        hidden_states: Hidden states after routing.
    """
    # Check constraints.
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"

    original_shape = hidden_states.shape
    assert len(original_shape) == 2

    num_tokens = hidden_states.shape[:-1].numel()
    num_experts = w1.shape[0]
    dtype = hidden_states.dtype
    device = hidden_states.device
    assert dtype in [torch.float32, torch.float16, torch.bfloat16
                     ], "Only float32, float16, and bfloat16 are supported"

    if expert_map is not None:
        # Generate token indices and flatten
        token_indices = (torch.arange(num_tokens,
                                      device=device,
                                      dtype=torch.int64).unsqueeze(1).expand(
                                          -1, top_k).reshape(-1))

        # Flatten token-to-expert mappings and map to local experts
        weights_flat = topk_weights.view(-1)
        experts_flat = topk_ids.view(-1)
        local_experts_flat = expert_map[experts_flat]

        # Filter valid token-expert pairs
        mask = local_experts_flat != -1
        filtered_weights = torch.where(
            mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype)
        filtered_experts = torch.where(
            mask, local_experts_flat,
            torch.full_like(local_experts_flat,
                            num_experts)).to(topk_ids.dtype)

        # Sort by local expert IDs
        sort_indices = torch.argsort(filtered_experts)
        sorted_token_indices = token_indices[sort_indices]
        sorted_weights = filtered_weights[sort_indices]

        # Compute per-expert token counts; the extra bin collects the
        # sentinel id `num_experts` assigned to invalid pairs.
        # This is equivalent to but faster than:
        # >>> token_counts = torch.bincount(filtered_experts,
        # ...                               minlength=num_experts + 1)[:-1]
        token_counts = torch.zeros(num_experts + 1,
                                   device=device,
                                   dtype=torch.int64)
        ones = torch.ones_like(filtered_experts, dtype=torch.int64)
        token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
        token_counts = token_counts[:num_experts]
        expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64)

        # Rearrange hidden_states
        sorted_hidden_states = hidden_states[sorted_token_indices]
    else:
        row_idx_len = num_tokens * top_k
        row_idx = (torch.arange(0,
                                row_idx_len,
                                dtype=torch.int32,
                                device=device).view(top_k, -1).permute(
                                    1, 0).contiguous())
        sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
            hidden_states,
            row_idx=row_idx,
            expert_idx=topk_ids,
            active_num=num_tokens)

        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
            expanded_expert_idx, num_experts)
        expert_tokens = expert_tokens.to(torch.int64)

    w1 = w1.transpose(1, 2)
    gate_up_out_list = torch_npu.npu_grouped_matmul(
        x=[sorted_hidden_states],
        weight=[w1],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=expert_tokens,
    )

    # TODO: Remove this in the future.
    gate_up_out = torch.cat(gate_up_out_list, dim=0)
    gate_up_out = torch_npu.npu_swiglu(gate_up_out)

    w2 = w2.transpose(1, 2)
    down_out_list = torch_npu.npu_grouped_matmul(
        x=[gate_up_out],
        weight=[w2],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=expert_tokens,
    )

    down_out_list = torch.cat(down_out_list, dim=0)

    if expert_map is not None:
        weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)

        final_hidden_states = torch.zeros(*original_shape,
                                          device=hidden_states.device,
                                          dtype=dtype)
        final_hidden_states.index_add_(0, sorted_token_indices,
                                       weighted_down_out)

        # TODO: This should not happen! Look into it!
        # fill nan with 0.0
        final_hidden_states[torch.isnan(final_hidden_states)] = 0.0
    else:
        # TODO: Reorder device memory 2 times here, replace the current
        # implementation when suitable operators become available.
        final_hidden_states = torch_npu.npu_moe_finalize_routing(
            down_out_list,
            skip1=None,
            skip2=None,
            bias=None,
            scales=topk_weights,
            expanded_src_to_dst_row=expanded_row_idx,
            export_for_source_row=topk_ids,
        )

    return final_hidden_states
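

# Since this module is adapted from vLLM's MoE kernel tests, a plain-PyTorch
# reference is handy for checking the NPU path numerically. The sketch below
# is an illustrative addition, not part of the adapted vLLM code: it assumes
# the same weight layout as fused_experts, that npu_swiglu applies SiLU to
# the first half of the gate/up projection (vLLM's SiluAndMul convention),
# and it omits expert_map handling.
def torch_native_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
) -> torch.Tensor:
    out = torch.zeros_like(hidden_states)
    for expert_id in range(w1.shape[0]):
        # Token/slot pairs routed to this expert.
        token_idx, slot_idx = torch.where(topk_ids == expert_id)
        if token_idx.numel() == 0:
            continue
        x = hidden_states[token_idx]
        # gate_up: (n, intermediate_size * 2); SwiGLU = silu(gate) * up.
        gate_up = x @ w1[expert_id].transpose(0, 1)
        gate, up = gate_up.chunk(2, dim=-1)
        act = torch.nn.functional.silu(gate) * up
        down = act @ w2[expert_id].transpose(0, 1)
        # Accumulate each expert's output scaled by its routing weight.
        out.index_add_(
            0, token_idx,
            down * topk_weights[token_idx, slot_idx].unsqueeze(-1))
    return out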
def native_grouped_topk(
    topk_weights: torch.Tensor,
    num_expert_group: Optional[int],
    topk_group: Optional[int],
):
    topk_group = 0 if topk_group is None else topk_group
    num_expert_group = 0 if num_expert_group is None else num_expert_group

    num_token = topk_weights.shape[0]
    grouped_weights = topk_weights.view(num_token, num_expert_group,
                                        -1).max(dim=-1).values
    topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
                                    k=topk_group,
                                    dim=-1,
                                    sorted=False)[1]
    topk_group_mask = torch.zeros_like(grouped_weights)
    topk_group_mask.scatter_(1, topk_group_indices, 1)
    topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
    topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)

    return topk_weights
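

# A small worked example for native_grouped_topk (illustrative values, not
# part of the adapted vLLM code; assumes the number of experts is divisible
# by num_expert_group): with 8 experts in 4 groups of 2 and topk_group=2,
# only scores inside the 2 strongest groups survive, the rest are zeroed.
#
#   >>> scores = torch.tensor([[0.1, 0.9, 0.2, 0.3, 0.8, 0.1, 0.4, 0.2]])
#   >>> native_grouped_topk(scores, num_expert_group=4, topk_group=2)
#   tensor([[0.1000, 0.9000, 0.0000, 0.0000, 0.8000, 0.1000, 0.0000, 0.0000]])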
def select_experts(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    use_grouped_topk: bool,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Select top-k experts based on router logits.

    Args:
        hidden_states: Hidden states of shape (num_tokens, hidden_size).
        router_logits: Router logits of shape (num_tokens, num_experts).
        top_k: Number of experts to select.
        use_grouped_topk: Whether to group experts before selecting top-k.
        renormalize: Whether to renormalize the routing weights.
        topk_group: Number of expert groups to select from.
        num_expert_group: Number of experts in each group.
        custom_routing_function: Custom routing function.
        scoring_func: Scoring function to use.
        e_score_correction_bias: Correction bias to apply to expert scores.

    Returns:
        topk_weights: Routing weights of shape (num_tokens, top_k).
        topk_ids: Selected expert IDs of shape (num_tokens, top_k).

    Raises:
        ValueError: If an unsupported scoring function is provided.
    """
    assert hidden_states.shape[0] == router_logits.shape[0], (
        "Number of tokens mismatch")

    if custom_routing_function is not None:
        raise NotImplementedError(
            "Custom routing functions are not supported yet")

    if scoring_func == "softmax":
        # NOTE: vLLM uses dtype=torch.float here
        topk_weights = router_logits.softmax(dim=-1)
    elif scoring_func == "sigmoid":
        topk_weights = router_logits.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    if use_grouped_topk:
        assert topk_group is not None
        assert num_expert_group is not None

        if e_score_correction_bias is not None:
            # Store original scores before applying correction bias. We use
            # biased scores for expert selection but original scores for
            # routing weights.
            original_weights = topk_weights
            topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)

        # TODO: Change to npu_group_topk when the latest CANN and NNAL are
        # available:
        # >>> torch_npu._npu_group_topk(topk_weights,
        # ...                           group_num=num_expert_group,
        # ...                           k=topk_group)
        topk_weights = native_grouped_topk(topk_weights, num_expert_group,
                                           topk_group)
        if e_score_correction_bias is not None:
            topk_ids = torch.topk(topk_weights, k=top_k, dim=-1,
                                  sorted=False)[1]
            # Use original unbiased scores for the routing weights
            topk_weights = original_weights.gather(1, topk_ids)
        else:
            topk_weights, topk_ids = torch.topk(topk_weights,
                                                k=top_k,
                                                dim=-1,
                                                sorted=False)
    else:
        topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)

    topk_weights = topk_weights.to(hidden_states.dtype)
    # Required by npu_moe_init_routing
    topk_ids = topk_ids.to(torch.int32)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids


def forward_oot(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    **kwargs,
):
    assert router_logits.shape[
        1] == global_num_experts, "Number of global experts mismatch"

    topk_weights, topk_ids = select_experts(
        hidden_states=x,
        router_logits=router_logits,
        top_k=top_k,
        use_grouped_topk=use_grouped_topk,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
    )

    return fused_experts(hidden_states=x,
                         w1=layer.w13_weight,
                         w2=layer.w2_weight,
                         topk_weights=topk_weights,
                         topk_ids=topk_ids,
                         top_k=top_k,
                         expert_map=expert_map)


# Override the out-of-tree forward of the unquantized fused-MoE method so
# vLLM dispatches to the NPU implementation above.
UnquantizedFusedMoEMethod.forward_oot = forward_oot
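

# Minimal smoke-test sketch for the routing + fused-experts path. This is an
# illustrative addition rather than part of the adapted vLLM code: it assumes
# an Ascend NPU is reachable through torch_npu, and the shapes below
# (4 tokens, 8 experts, hidden_size 16, intermediate_size 32) are arbitrary.
if __name__ == "__main__":
    device = "npu"
    num_tokens, num_experts, top_k = 4, 8, 2
    hidden_size, intermediate_size = 16, 32

    hidden_states = torch.randn(num_tokens, hidden_size,
                                dtype=torch.float16, device=device)
    router_logits = torch.randn(num_tokens, num_experts,
                                dtype=torch.float16, device=device)
    w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size,
                     dtype=torch.float16, device=device)
    w2 = torch.randn(num_experts, hidden_size, intermediate_size,
                     dtype=torch.float16, device=device)

    topk_weights, topk_ids = select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=top_k,
        use_grouped_topk=False,
        renormalize=True,
    )
    out = fused_experts(hidden_states, w1, w2, topk_weights, topk_ids, top_k)
    assert out.shape == (num_tokens, hidden_size)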