# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Optional

import torch
import torch_npu
from vllm.distributed.parallel_state import get_ep_group
from vllm.forward_context import get_forward_context

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.distributed.tensor_parallel import (
    all_gather_last_dim_from_tensor_parallel_region, all_to_all_hp2sp,
    all_to_all_sp2hp, gather_from_sequence_parallel_region,
    reduce_scatter_last_dim_to_tensor_parallel_region)
from vllm_ascend.ops.comm_utils import async_all_to_all
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version


class MoEDispatcherConfig:

    def __init__(self):
        self.num_local_experts: int = 0
        self.num_moe_experts: int = 0
        self.moe_pad_expert_input_to_capacity: bool = False
        self.moe_expert_capacity_factor: Optional[float] = None
        self.moe_router_topk: int = 2
        self.moe_grouped_gemm: bool = False
        self.group_topk: int = 0
        self.num_groups: int = 1
        self.expert_bias: Optional[torch.Tensor] = None
        self.scaling_factor: Optional[float] = None
        self.is_fused: bool = True

    def set_num_local_experts(self, num_local_experts):
        self.num_local_experts = num_local_experts
        return self

    def set_num_moe_experts(self, num_moe_experts):
        self.num_moe_experts = num_moe_experts
        return self

    def set_moe_pad_expert_input_to_capacity(
            self, moe_pad_expert_input_to_capacity):
        self.moe_pad_expert_input_to_capacity = moe_pad_expert_input_to_capacity
        return self

    def set_moe_expert_capacity_factor(self, moe_expert_capacity_factor):
        self.moe_expert_capacity_factor = moe_expert_capacity_factor
        return self

    def set_moe_router_topk(self, moe_router_topk):
        self.moe_router_topk = moe_router_topk
        return self

    def set_moe_grouped_gemm(self, moe_grouped_gemm):
        self.moe_grouped_gemm = moe_grouped_gemm
        return self

    def set_group_topk(self, group_topk):
        self.group_topk = group_topk
        return self

    def set_num_groups(self, num_groups):
        self.num_groups = num_groups
        return self

    def set_expert_bias(self, expert_bias):
        self.expert_bias = expert_bias
        return self

    def set_scaling_factor(self, scaling_factor):
        self.scaling_factor = scaling_factor
        return self

    def set_is_fused(self, is_fused):
        self.is_fused = is_fused
        return self

    def build(self):
        return self
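
# Illustrative use of the fluent builder above (the values shown are
# hypothetical and not taken from any particular model):
#
#     config = (MoEDispatcherConfig()
#               .set_num_local_experts(4)
#               .set_num_moe_experts(64)
#               .set_moe_router_topk(8)
#               .set_is_fused(False)
#               .build())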


class MoEDispatcher:

    def __init__(self, config: MoEDispatcherConfig) -> None:
        """
        Initialize the MoE Token Dispatcher.
        """
        self.config = config
        self.shared_experts = None

    def set_shared_experts(self, shared_experts):
        self.shared_experts = shared_experts

    @property
    def ep_group(self):
        """Get expert model parallel group."""
        return get_ep_group().device_group

    @property
    def ep_rank(self):
        return get_ep_group().rank_in_group

    @property
    def ep_size(self):
        return get_ep_group().world_size

    @property
    def tp_ep_group(self):
        """Get expert tensor and model parallel group."""
        return None

    @property
    def tp_ep_size(self):
        return 1


class MoEAlltoAllSeqOverLapDispatcher(MoEDispatcher):
    """
    The implementation of the AlltoAll-based token dispatcher, which handles
    token dispatching on the sequence level instead of the token level. The
    core of this implementation lies in each device dispatching the entire
    sequence, with the hidden state being partitioned.
    """

    overlap_stream = None

    def __init__(self, config: MoEDispatcherConfig):
        """
        Initialize the AlltoAllSeq token dispatcher.

        Args:
            config (MoEDispatcherConfig): Configuration for the transformer
                model.
        """
        super().__init__(config)
        self.num_local_experts = config.num_local_experts
        self.config = config
        # use MOEAlltoAllSEQTokenDispatcher to init
        self.hidden_shape = None
        self.num_input_tokens = None
        self.num_experts = config.num_moe_experts
        assert self.num_local_experts > 0, "Expected at least one expert"
        if self.num_local_experts > 1:
            self.expert_ids_per_ep_rank = torch.tensor(
                [i % self.num_local_experts for i in range(self.num_experts)],
                dtype=torch.int32,
                device=torch.npu.current_device(),
            )

        local_expert_indices_offset = (self.ep_rank * self.num_local_experts)
        self.local_expert_indices = [
            local_expert_indices_offset + i
            for i in range(self.num_local_experts)
        ]
        assert (len(self.local_expert_indices) == self.num_local_experts
                ), "Invalid local expert indices"
        for i in range(len(self.local_expert_indices) - 1):
            assert (self.local_expert_indices[i] ==
                    self.local_expert_indices[i + 1] -
                    1), "local_expert_indices must be continuous"
        self.probs = None
        self.input_splits = None
        self.output_splits = None
        self.routing_map = None
        self.hidden_shape_before_permute = None

        # [tp_ep_size * ep_size, num_local_experts]. Represents the number of
        # tokens sent to each local expert by all ranks.
        self.num_global_tokens_per_local_expert_cpu = None
        self.num_global_tokens_per_local_expert = None

        # A device stream synchronization is needed in self.token_permutation()
        # in some cases, because there are several non-blocking DtoH data
        # transfers called in self.preprocess(). The synchronization happens
        # at different points based on MoE settings, as late as possible.
        # Valid sync points are "before_permutation_1", "before_ep_alltoall",
        # "before_finish", and "no_sync".
        self.device_sync_point = "no_sync"

        # cached intermediate tensors.
        self.cached_permutated_local_input_tokens = None
        self.cached_global_input_tokens = None
        self.cached_shared_expert_output = None
        self.tokens_per_expert = None
        self.perm1_finish_event = None
        self.global_input_tokens_local_experts_indices = None

        if MoEAlltoAllSeqOverLapDispatcher.overlap_stream is None:
            MoEAlltoAllSeqOverLapDispatcher.overlap_stream = torch.npu.Stream()
        self.overlap_stream = MoEAlltoAllSeqOverLapDispatcher.overlap_stream
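
    # Note: every instance shares the single class-level `overlap_stream`
    # created lazily above, so shared-expert computation can overlap the
    # AlltoAll communication without allocating one extra NPU stream per
    # layer.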

    def preprocess(self,
                   indices: torch.Tensor,
                   with_sync=True) -> torch.Tensor:
        """
        Preprocess routing indices for AlltoAll communication and token
        permutation.

        This method computes the number of tokens assigned to each expert
        based on the routing indices. It also initializes the necessary data
        structures for AlltoAll communication, such as input and output
        splits, and the mapping between global tokens and local experts.

        Args:
            indices (torch.Tensor): The mapping of tokens to experts, with
                shape [num_tokens, num_experts].

        Returns:
            torch.Tensor: Tensor containing the number of tokens assigned to
                each local expert.
        """
        num_local_tokens_per_expert = torch.histc(indices,
                                                  bins=self.num_experts,
                                                  min=0,
                                                  max=self.num_experts)
        # num_local_tokens_per_expert: [num_experts]

        ep_size = self.ep_size

        # Dropless
        self.num_out_tokens = indices.numel()
        if self.ep_size > 1 or self.num_local_experts > 1:
            # Token dropless and EP enabled. A synchronization is needed
            # before expert parallel AlltoAll communication to get the
            # `input_splits` and `output_splits` CPU values.
            self.device_sync_point = "before_ep_alltoall"
        else:
            # Token dropless and no EP. A synchronization is needed to get
            # the `tokens_per_expert` CPU value.
            self.device_sync_point = "before_finish"

        if ep_size > 1:
            # ===================================================
            # Calculate input_splits, output_splits for alltoall-v.
            # ===================================================
            self.input_splits = (num_local_tokens_per_expert.reshape(
                ep_size, self.num_local_experts).sum(axis=1).to(
                    torch.device("cpu"), non_blocking=True).numpy())
            num_global_tokens_per_expert = gather_from_sequence_parallel_region(
                num_local_tokens_per_expert,
                group=self.ep_group).reshape(ep_size, self.num_experts)
            self.num_global_tokens_per_local_expert = \
                num_global_tokens_per_expert[:, self.local_expert_indices[0]:
                                             self.local_expert_indices[-1] + 1]
            if self.num_global_tokens_per_local_expert is None:
                raise ValueError(
                    "num_global_tokens_per_local_expert must be set before sum."
                )
            self.output_splits = (self.num_global_tokens_per_local_expert.sum(
                axis=-1).to(torch.device("cpu"), non_blocking=True).numpy())
            num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(
                axis=0)
            # ===================================================
            # num_global_tokens_per_expert: [ep_size, num_experts]
            # num_global_tokens_per_local_expert: [ep_size, num_local_experts]
            # num_tokens_per_local_expert: [num_local_experts]
            # ===================================================
        else:
            self.num_global_tokens_per_local_expert = \
                num_local_tokens_per_expert.reshape(-1, self.num_experts)
            num_tokens_per_local_expert = num_local_tokens_per_expert

        if self.num_local_experts > 1 and with_sync:
            if self.num_global_tokens_per_local_expert is None:
                raise ValueError(
                    "num_global_tokens_per_local_expert must be set before operations."
                )
            self.device_sync_point = "no_sync"
            self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
                self.expert_ids_per_ep_rank,
                self.num_global_tokens_per_local_expert.ravel())

        return num_tokens_per_local_expert
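
    # A plain-torch sketch of how `preprocess` derives the alltoall-v splits
    # (CPU-runnable; `indices`, `ep_size`, and `num_local_experts` are example
    # inputs, and `.float()` is needed because torch.histc on CPU only
    # accepts floating-point tensors):
    #
    #     counts = torch.histc(indices.float(), bins=num_experts,
    #                          min=0, max=num_experts)
    #     # tokens this rank sends to each EP rank
    #     input_splits = counts.reshape(ep_size, num_local_experts).sum(axis=1)
    #     # output_splits would come from gathering every rank's `counts` and
    #     # summing the columns owned by the local experts, per source rank.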
""" self.hidden_shape = hidden_states.shape self.probs = probs self.top_indices = routing_map assert probs.dim() == 2, "Expected 2D tensor for probs" assert routing_map.dim() == 2, "Expected 2D tensor for routing map" # Permutation 1: input to AlltoAll input def alltoall_token_permutation1(hidden_states, routing_map): assert self.hidden_shape is not None hidden_states = hidden_states.view(-1, self.hidden_shape[-1]) tokens_per_expert = self.preprocess(routing_map) if self.tp_ep_size > 1: hidden_states = all_to_all_sp2hp(hidden_states, group=self.tp_ep_group) self.hidden_shape_before_permute = hidden_states.shape if self.device_sync_point == "before_permutation_1": torch.npu.current_stream().synchronize() permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute( tokens=hidden_states, indices=self.top_indices, num_out_tokens=self.num_out_tokens, ) return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = alltoall_token_permutation1( hidden_states, routing_map) self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping # permute 1 ep_group = self.ep_group # Perform expert parallel AlltoAll communication if self.device_sync_point == "before_ep_alltoall": torch.npu.current_stream().synchronize() _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all( permutated_local_input_tokens, self.output_splits, self.input_splits, ep_group, ) # shared experts compute if self.shared_experts is not None: (share_experts_output), *_ = self.shared_experts(hidden_states) else: share_experts_output = None permute1_ep_all_to_all_handle.wait() permutated_local_input_tokens.untyped_storage().resize_(0) def alltoall_token_permutation2(global_input_tokens): # Permutation 2: Sort tokens by local expert. if self.num_local_experts > 1: global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute( global_input_tokens, self.global_input_tokens_local_experts_indices) # Perform tensor parallel AllGather on the hidden dimension to obtain the input tokens. # global_input_tokens: [SEQL, H/TP] -> [SEQL, H] if self.tp_ep_size > 1 and self.config.moe_grouped_gemm: global_input_tokens = all_gather_last_dim_from_tensor_parallel_region( global_input_tokens, self.tp_ep_group) if self.device_sync_point == "before_finish": torch.npu.current_stream().synchronize() return global_input_tokens # token premute2 input global_input_tokens = alltoall_token_permutation2(global_input_tokens) return share_experts_output, global_input_tokens, tokens_per_expert def token_unpermutation(self, hidden_states: torch.Tensor, bias: torch.Tensor = None): """ Reverse the token permutation to restore the original order. Args: hidden_states (torch.Tensor): Output from local experts. bias (torch.Tensor, optional): Bias tensor (not supported). Returns: Tuple[torch.Tensor, Optional[torch.Tensor]]: - Unpermuted token embeddings in the original order. - None (bias is not supported). 
""" def alltoall_token_unpermutation1(hidden_states): assert bias is None, "Bias is not supported in MoEAlltoAllSeqTokenDispatcher" # Perform tensor parallel Reduce-Scatter # hidden_states: [SEQL, H] -> [SEQL, H/TP] if self.tp_ep_size > 1: hidden_states = reduce_scatter_last_dim_to_tensor_parallel_region( hidden_states, group=self.tp_ep_group) # Unpermutation 2: expert output to AlltoAll input if hidden_states.shape[0] > 0 and self.num_local_experts > 1: hidden_states = torch_npu.npu_moe_token_unpermute( hidden_states, self.reversed_global_input_permutation_mapping) return hidden_states hidden_states = alltoall_token_unpermutation1(hidden_states) ep_group = self.ep_group # Perform expert parallel AlltoAll communication # hidden_states: [SEQL, H] -> [SEQL, H/TP] _, permutated_local_input_tokens, handle = async_all_to_all( hidden_states, self.input_splits, self.output_splits, ep_group) handle.wait() hidden_states.untyped_storage().resize_(0) def alltoall_token_unpermutation2(permutated_local_input_tokens): # Unpermutation 1: AlltoAll output to output output = torch_npu.npu_moe_token_unpermute( permuted_tokens=permutated_local_input_tokens, sorted_indices=self.reversed_local_input_permutation_mapping. to(torch.int32), probs=self.probs, restore_shape=self.hidden_shape_before_permute) # Perform tensor parallel AlltoAll communication # output: [S*B, H/TP] -> [S*B/TP, H] if self.tp_ep_size > 1: output = all_to_all_hp2sp(output, self.tp_ep_group) # Reshape the output tensor output = output.view(self.hidden_shape) return output output = alltoall_token_unpermutation2(permutated_local_input_tokens) self.input_splits = None self.output_splits = None self.num_global_tokens_per_local_expert = None self.num_global_tokens_per_local_expert_cpu = None return output, None class MoETokenDispatcher(ABC): def __init__(self, **kwargs) -> None: """ Initialize the MoE Token Dispatcher. 
""" self.top_k = kwargs.get("top_k") self.num_experts = kwargs.get("num_experts") @property def ep_group(self): """Get expert model parallel group.""" return get_ep_group().device_group @property def ep_rank(self): return get_ep_group().rank_in_group @property def ep_size(self): return get_ep_group().world_size @abstractmethod def token_permutation( self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, expert_map: torch.Tensor, log2phy: torch.Tensor = None, global_redundant_expert_num: int = 0, shared_gate_up: Optional[Any] = None, shared_dequant_scale: Optional[Any] = None, shared_experts: Optional[Any] = None, ): raise NotImplementedError("Dispatch function not implemented.") @abstractmethod def token_unpermutation(self, hidden_states: torch.Tensor, bias: torch.Tensor = None): raise NotImplementedError("Restore function not implemented.") class TokenDispatcherWithMC2(MoETokenDispatcher): def __init__(self, **kwargs): super().__init__(**kwargs) device_group = get_mc2_group().device_group # TODO: Try local_rank = ep_group.rank_in_group local_rank = torch.distributed.get_rank(group=device_group) backend = device_group._get_backend(torch.device("npu")) self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank) self.ep_rank_id = get_mc2_group().rank_in_group self.ep_world_size = get_mc2_group().world_size ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.with_quant = kwargs.get("with_quant") self.enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2") self.need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 or self.torchair_graph_enabled) # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine self.a3_need_extra_args = \ get_ascend_soc_version() == AscendSocVersion.A3 self.output = None self.dynamic_scale = None self.assist_info_for_combine = None self.ep_recv_counts = None self.shared_act = None self.topk_ids = None self.topk_weights = None self.shared_experts = None def get_permute_mc2_kwargs(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, expert_map: torch.Tensor, global_redundant_expert_num: int = 0): quant_mode = 0 forward_context = get_forward_context() mc2_mask = forward_context.mc2_mask if self.with_quant: if (expert_map is not None): moe_expert_num = len(expert_map) + global_redundant_expert_num else: moe_expert_num = global_redundant_expert_num else: moe_expert_num = len(expert_map) kwargs_mc2 = { "x": hidden_states, "expert_ids": topk_ids, "expert_shard_type": 0, "shared_expert_rank_num": 0, "moe_expert_num": moe_expert_num, "global_bs": 0, } stage1_kwargs = { "scales": None, "quant_mode": quant_mode, "group_ep": self.moe_all_to_all_group_name, "ep_world_size": self.ep_world_size, "ep_rank_id": self.ep_rank_id, } if self.need_extra_args: stage1_kwargs.update({ "group_tp": self.moe_all_to_all_group_name, "tp_world_size": 1, "tp_rank_id": 0, }) if self.a3_need_extra_args and self.enable_dispatch_v2: stage1_kwargs.update({ "x_active_mask": mc2_mask, }) kwargs_mc2.update(stage1_kwargs) return kwargs_mc2 def token_permutation( self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, expert_map: torch.Tensor, log2phy: torch.Tensor = None, global_redundant_expert_num: int = 0, shared_gate_up: Optional[Any] = None, shared_dequant_scale: Optional[Any] = None, shared_experts: Optional[Any] = None, ): self.expert_map = expert_map self.topk_ids = 

    def token_permutation(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        expert_map: torch.Tensor,
        log2phy: torch.Tensor = None,
        global_redundant_expert_num: int = 0,
        shared_gate_up: Optional[Any] = None,
        shared_dequant_scale: Optional[Any] = None,
        shared_experts: Optional[Any] = None,
    ):
        self.expert_map = expert_map
        self.topk_ids = topk_ids
        self.topk_weights = topk_weights
        self.shared_experts = shared_experts

        kwargs_mc2 = self.get_permute_mc2_kwargs(hidden_states, topk_weights,
                                                 topk_ids, expert_map,
                                                 global_redundant_expert_num)

        self.output = torch_npu.npu_moe_distribute_dispatch_v2(
            **kwargs_mc2
        ) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
            **kwargs_mc2)
        # comm_stream.wait_stream(torch.npu.current_stream())
        expand_x, self.dynamic_scale, self.assist_info_for_combine, \
            expert_token_nums, self.ep_recv_counts = self.output[0:5]

        if self.with_quant:
            if shared_experts is not None:
                with npu_stream_switch("moe_secondary", 0):
                    npu_wait_tensor(shared_gate_up, expand_x)
                    shared_act_out = shared_experts.act_fn(
                        (shared_gate_up, shared_dequant_scale))
                    self.shared_act, self.swiglu_out_scale = \
                        shared_act_out[0], shared_act_out[1]
        else:
            if shared_experts is not None:
                with npu_stream_switch("moe_secondary", 0):
                    npu_wait_tensor(hidden_states, topk_weights)
                    shared_gate_up, _ = shared_experts.gate_up_proj(
                        hidden_states)
                    npu_wait_tensor(shared_gate_up, expand_x)
                    self.shared_act = shared_experts.act_fn(shared_gate_up)

        group_list_type = 1
        return group_list_type, expand_x, expert_token_nums

    def get_unpermute_mc_kwargs(self, hidden_states: torch.Tensor):
        assert self.expert_map is not None
        assert self.topk_weights is not None
        assert self.topk_ids is not None
        assert self.output is not None
        moe_expert_num = len(self.expert_map)
        forward_context = get_forward_context()
        mc2_mask = forward_context.mc2_mask
        # moeCombine
        kwargs_mc2 = {
            "expand_x": hidden_states,
            "expert_ids": self.topk_ids,
            "expert_scales": self.topk_weights.to(torch.float32),
            "expert_shard_type": 0,
            "shared_expert_rank_num": 0,
            "moe_expert_num": moe_expert_num,
            "global_bs": 0,
        }
        if self.with_quant:
            tp_recv_counts = torch.empty(1,
                                         dtype=torch.int32,
                                         device=hidden_states.device)
        else:
            tp_recv_counts = self.output[5]
        stage3_kwargs = {
            "ep_send_counts": self.ep_recv_counts,
            "group_ep": self.moe_all_to_all_group_name,
            "ep_world_size": self.ep_world_size,
            "ep_rank_id": self.ep_rank_id,
        }
        if self.enable_dispatch_v2:
            stage3_kwargs.update({
                "assist_info_for_combine": self.assist_info_for_combine,
            })
        else:
            stage3_kwargs.update({
                "expand_idx": self.assist_info_for_combine,
            })
        if self.need_extra_args:
            stage3_kwargs.update({
                "tp_send_counts": tp_recv_counts,
                "group_tp": self.moe_all_to_all_group_name,
                "tp_world_size": 1,
                "tp_rank_id": 0,
            })
        if self.a3_need_extra_args and self.enable_dispatch_v2:
            stage3_kwargs.update({
                "x_active_mask": mc2_mask,
            })
        kwargs_mc2.update(stage3_kwargs)
        return kwargs_mc2

    def token_unpermutation(self,
                            hidden_states: torch.Tensor,
                            bias: torch.Tensor = None):
        kwargs_mc2 = self.get_unpermute_mc_kwargs(hidden_states)
        hidden_states = torch_npu.npu_moe_distribute_combine_v2(
            **kwargs_mc2
        ) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
            **kwargs_mc2)
        if self.shared_experts is None:
            return hidden_states
        else:
            if self.with_quant:
                with npu_stream_switch("moe_secondary", 0):
                    npu_wait_tensor(self.shared_act, hidden_states)
                    shared_hidden_states, _ = self.shared_experts.down_proj(
                        (self.shared_act, self.swiglu_out_scale))
            else:
                with npu_stream_switch("moe_secondary", 0):
                    npu_wait_tensor(self.shared_act, hidden_states)
                    shared_hidden_states, _ = self.shared_experts.down_proj(
                        self.shared_act)
            return hidden_states, shared_hidden_states
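
    # Note the dispatch/combine symmetry above: the `ep_recv_counts` returned
    # by the dispatch kernel are handed back to the combine kernel as
    # `ep_send_counts`, and `assist_info_for_combine` (or `expand_idx` on the
    # v1 path) carries the permutation needed to route tokens home.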


class TokenDispatcherWithAllGather(MoETokenDispatcher):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.apply_router_weight_on_input = kwargs.get(
            "apply_router_weight_on_input")
        self.top_k = kwargs.get("top_k")
        self.max_num_tokens = kwargs.get("max_num_tokens")
        ep_size = kwargs.get("ep_size")
        if ep_size is not None:
            self.num_experts_local = self.num_experts // ep_size
        self.with_quant = kwargs.get("with_quant")

        self.sorted_weights = None
        self.expanded_row_idx = None
        self.sorted_token_indices = None
        self.original_shape = None
        self.mask = None
        self.expert_map = None
        self.topk_weights = None
        self.topk_ids = None

    def token_permutation(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        expert_map: torch.Tensor,
        log2phy: torch.Tensor = None,
        global_redundant_expert_num: int = 0,
        shared_gate_up: Optional[Any] = None,
        shared_dequant_scale: Optional[Any] = None,
        shared_experts: Optional[Any] = None,
    ):
        self.original_shape = hidden_states.shape
        # assert len(original_shape) == 2

        num_tokens = hidden_states.shape[:-1].numel()
        dtype = hidden_states.dtype
        device = hidden_states.device
        self.expert_map = expert_map
        self.topk_weights = topk_weights
        self.topk_ids = topk_ids
        # assert dtype in [torch.float32, torch.float16, torch.bfloat16
        #                  ], "Only float32, float16, and bfloat16 are supported"

        if self.apply_router_weight_on_input:
            assert (topk_weights.dim() == 2
                    ), "`topk_weights` should be in shape (num_tokens, topk)"
            _, topk = topk_weights.shape
            assert (
                topk == 1
            ), "Only support topk=1 when `apply_router_weight_on_input` is True"
            hidden_states = hidden_states * \
                topk_weights.to(hidden_states.dtype)

        if expert_map is not None:
            # Generate token indices and flatten
            token_indices = (torch.arange(
                num_tokens, device=device,
                dtype=torch.int64).unsqueeze(1).expand(
                    -1, self.top_k).reshape(-1))

            # Flatten token-to-expert mappings and map to local experts
            weights_flat = topk_weights.view(-1)
            experts_flat = topk_ids.view(-1)
            local_experts_flat = expert_map[experts_flat]

            # Filter valid token-expert pairs
            self.mask = local_experts_flat != -1
            filtered_weights = torch.where(
                self.mask, weights_flat,
                torch.zeros_like(weights_flat)).to(dtype)
            filtered_experts = torch.where(
                self.mask, local_experts_flat,
                torch.full_like(local_experts_flat,
                                self.num_experts_local)).to(topk_ids.dtype)

            # Sort by local expert IDs
            sort_indices = torch.argsort(filtered_experts.view(torch.float32))
            self.sorted_token_indices = token_indices[sort_indices]
            self.sorted_weights = filtered_weights[sort_indices]

            # Compute token counts with minlength of num_experts
            # This is equivalent to but faster than:
            # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
            token_counts = torch.zeros(self.num_experts_local + 1,
                                       device=device,
                                       dtype=torch.int64)
            ones = torch.ones_like(filtered_experts, dtype=torch.int64)
            token_counts.scatter_add_(0, filtered_experts.to(torch.int64),
                                      ones)
            token_counts = token_counts[:self.num_experts_local]

            # Rearrange hidden_states
            sorted_hidden_states = hidden_states[self.sorted_token_indices]
            if self.with_quant:
                # group_list_type 1 expects raw per-expert counts; the cumsum
                # is only needed for type 0.
                expert_tokens = token_counts
                group_list_type = 1
            else:
                expert_tokens = torch.cumsum(token_counts,
                                             dim=0,
                                             dtype=torch.int64)
                group_list_type = 0
        else:
            row_idx_len = num_tokens * self.top_k
            row_idx = (torch.arange(0,
                                    row_idx_len,
                                    dtype=torch.int32,
                                    device=device).view(
                                        self.top_k,
                                        -1).permute(1, 0).contiguous())
            active_num = self.max_num_tokens if self.max_num_tokens is not None else num_tokens
            sorted_hidden_states, self.expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
                hidden_states,
                row_idx=row_idx,
                expert_idx=topk_ids,
                active_num=active_num)
            expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
                expanded_expert_idx, self.num_experts_local)
            expert_tokens = expert_tokens.to(torch.int64)
            group_list_type = 0

        return group_list_type, sorted_hidden_states, expert_tokens
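
    # group_list semantics used above (illustrative): with group_list_type 0,
    # `expert_tokens` holds a cumulative sum (prefix offsets); with type 1 it
    # holds raw per-expert counts, e.g. counts [2, 0, 3] <-> cumsum [2, 2, 5].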

    def token_unpermutation(self,
                            hidden_states: torch.Tensor,
                            bias: torch.Tensor = None):
        assert self.original_shape is not None
        dtype = hidden_states.dtype
        device = hidden_states.device
        if self.expert_map is not None:
            assert self.mask is not None
            assert self.sorted_token_indices is not None
            assert self.sorted_weights is not None
            weighted_down_out = hidden_states * \
                self.sorted_weights.unsqueeze(1)

            final_hidden_states = torch.zeros(*self.original_shape,
                                              device=hidden_states.device,
                                              dtype=hidden_states.dtype)

            # TODO: npu_grouped_matmul outputs random values at
            # [num_valid_tokens:, ...]. This creates NaNs that index_add_
            # would mix in and harm accuracy; remove this mask-and-filter
            # once that is fixed.
            num_valid_tokens = self.mask.sum()
            valid_token_mask = torch.arange(
                0, self.sorted_token_indices.shape[0],
                device=device).unsqueeze(1) < num_valid_tokens
            valid_output = torch.where(
                valid_token_mask, weighted_down_out,
                torch.zeros_like(weighted_down_out)).to(dtype)
            final_hidden_states.index_add_(0, self.sorted_token_indices,
                                           valid_output)
        else:
            if self.with_quant:
                final_hidden_states = torch_npu.npu_moe_finalize_routing(
                    hidden_states,
                    skip1=None,
                    skip2=None,
                    bias=None,
                    scales=self.topk_weights,
                    expanded_src_to_dst_row=self.expanded_row_idx,
                    export_for_source_row=self.topk_ids,
                )
                if len(self.original_shape) == 3:
                    final_hidden_states = final_hidden_states.view(
                        self.original_shape)
            else:
                scales = torch.ones_like(
                    self.topk_weights
                ) if self.apply_router_weight_on_input else self.topk_weights
                # TODO: This reorders device memory twice; replace the current
                # implementation when suitable operators become available.
                final_hidden_states = torch_npu.npu_moe_finalize_routing(
                    hidden_states,
                    skip1=None,
                    skip2=None,
                    bias=None,
                    scales=scales,
                    expanded_src_to_dst_row=self.expanded_row_idx,
                    export_for_source_row=self.topk_ids,
                )
        return final_hidden_states
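
# The expert_map branch of token_unpermutation above recombines expert outputs
# with a weighted scatter-add; a minimal CPU sketch of that combine step
# (shapes are hypothetical):
#
#     out = torch.zeros(num_tokens, hidden_size)
#     out.index_add_(0, sorted_token_indices,
#                    expert_out * sorted_weights.unsqueeze(1))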


# mypy: disable-error-code="override"
class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.apply_router_weight_on_input = kwargs.get(
            "apply_router_weight_on_input")
        ep_size = kwargs.get("ep_size")
        self.local_ep = ep_size
        self.top_k = kwargs.get("top_k")
        assert self.local_ep is not None
        self.local_num_experts = self.num_experts // self.local_ep
        self.local_num_group = self.top_k // self.local_ep
        self.bsz = None

    def token_permutation(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        expert_map: torch.Tensor,
        log2phy: torch.Tensor = None,
        global_redundant_expert_num: int = 0,
        shared_gate_up: Optional[Any] = None,
        shared_dequant_scale: Optional[Any] = None,
        shared_experts: Optional[Any] = None,
    ):
        if self.apply_router_weight_on_input:
            assert (topk_weights.dim() == 2
                    ), "`topk_weights` should be in shape (num_tokens, topk)"
            _, topk = topk_weights.shape
            assert (
                topk == 1
            ), "Only support topk=1 when `apply_router_weight_on_input` is True"
            hidden_states = hidden_states * \
                topk_weights.to(hidden_states.dtype)

        self.bsz, _ = hidden_states.shape
        flatten_topk_ids = topk_ids.view(-1)
        self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
        self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32)
        self.sorted_hidden_states = hidden_states.index_select(
            0, self.sorted_topk_ids // self.local_num_group)

        experts_id = torch.arange(0,
                                  self.local_num_experts,
                                  dtype=topk_ids.dtype,
                                  device=topk_ids.device)
        num_tokens_per_expert = (
            flatten_topk_ids.unsqueeze(-1) == experts_id).to(
                torch.float32).sum(0)
        self.topk_scales = topk_weights.view(-1).index_select(
            0, self.sorted_topk_ids).unsqueeze(-1)
        group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
        return hidden_states, group_list

    def token_unpermutation(self,
                            hidden_states: torch.Tensor,
                            bias: torch.Tensor = None):
        assert self.local_ep is not None
        unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
            torch.int32)
        unsorted_hidden_states = hidden_states.index_select(
            0, unsorted_topk_ids)
        final_hidden_states = unsorted_hidden_states.reshape(
            self.bsz, self.top_k // self.local_ep, -1).sum(1)
        return final_hidden_states
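
# The Moge dispatcher relies on argsort-of-argsort producing the inverse
# permutation, so token_unpermutation restores the original row order before
# the top-k reduction. Illustrative identity check in plain torch:
#
#     perm = torch.argsort(flat_ids.float()).to(torch.int32)
#     inv = torch.argsort(perm.float()).to(torch.int32)
#     assert torch.equal(x[perm][inv], x)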