# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/kernels/test_moe.py
import os
from typing import Any, Callable, Optional

import torch
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
                                             get_tp_group)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import \
    FusedMoEConfig  # isort: skip
from vllm.model_executor.layers.fused_moe.config import \
    FusedMoEParallelConfig  # isort: skip
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
from vllm.model_executor.layers.quantization.base_config import \
    QuantizationConfig

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
                                              determine_default_log2phy_map)
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
                               get_all_reduce_merge_state,
                               get_rm_router_logits_state, is_310p,
                               vllm_version_is)


class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):

    def __init__(self, moe: FusedMoEConfig = None):
        super().__init__(moe=moe)
        vllm_config = get_current_vllm_config()
        self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
        self.max_model_len = vllm_config.model_config.max_model_len
        self.dynamic_eplb = get_ascend_config().dynamic_eplb

        try:
            device_group = get_mc2_group().device_group
            # TODO: Try local_rank = ep_group.rank_in_group
            local_rank = torch.distributed.get_rank(group=device_group)
            backend = device_group._get_backend(torch.device("npu"))
            self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
                local_rank)
        except AttributeError:
            self.moe_all_to_all_group_name = None

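    # Pad the loaded w13/w2 weights if needed and, except on 310P devices,
    # cast them to the ACL_FORMAT_FRACTAL_NZ layout.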
    def process_weights_after_loading(self, layer):
        super(UnquantizedFusedMoEMethod,
              self).process_weights_after_loading(layer)
        layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
            layer.w13_weight.data),
                                              requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
            layer.w2_weight.data),
                                             requires_grad=False)
        if not is_310p():
            layer.w13_weight.data = torch_npu.npu_format_cast(
                layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
            layer.w2_weight.data = torch_npu.npu_format_cast(
                layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)

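    # apply() selects the top-k experts for each token and then dispatches the
    # fused experts computation through the MoE communication method stored in
    # the current forward context.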
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = False,
        enable_force_load_balance: bool = False,
        shared_experts: Optional[Any] = None,
        **kwargs,
    ) -> torch.Tensor:
        topk_weights, topk_ids, row_idx = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts)

        topk_weights = topk_weights.to(x.dtype)

        # This is a naive implementation of expert load balancing that avoids
        # accumulating too many tokens on a single rank. It is currently only
        # activated during profile runs.
        if enable_force_load_balance and not self.use_aclgraph:
            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)

        moe_comm_method = get_forward_context().moe_comm_method
        return moe_comm_method.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            row_idx=row_idx,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            shared_experts=shared_experts,
            need_trans=True,
            dynamic_eplb=self.dynamic_eplb)


class AscendFusedMoE(FusedMoE):
    # The moe_counter class attribute is required during EPLB initialization
    # to identify the current layer index within the MoE model.
    moe_counter = -1

    def __init__(
        self,
        num_experts: int,  # Global number of experts
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = False,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        ep_size: Optional[int] = None,
        dp_size: Optional[int] = None,
        prefix: str = "",
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
    ):
        # TODO: This does not initialize the FusedMoE base class correctly;
        # fix it and make AscendFusedMoE.__init__() clearer.
        super().__init__(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,
            reduce_results=reduce_results,
            renormalize=renormalize,
            use_grouped_topk=use_grouped_topk,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
            quant_config=quant_config,
            tp_size=tp_size,
            ep_size=ep_size,
            dp_size=dp_size,
            prefix=prefix,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )

        AscendFusedMoE.moe_counter += 1
        self.moe_instance_id = AscendFusedMoE.moe_counter

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        vllm_config = get_current_vllm_config()
        self.moe_parallel_config = FusedMoEParallelConfig.make(
            tp_size_=(tp_size if tp_size is not None else
                      get_tensor_model_parallel_world_size()),
            dp_size_=(dp_size
                      if dp_size is not None else get_dp_group().world_size),
            vllm_parallel_config=vllm_config.parallel_config)

        self.top_k = top_k
        self.num_experts = num_experts
        self.global_num_experts = num_experts
        assert intermediate_size % self.tp_size == 0
        self.intermediate_size_per_partition = intermediate_size // self.tp_size
        self.reduce_results = reduce_results
        self.renormalize = renormalize
        self.use_grouped_topk = use_grouped_topk
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.custom_routing_function = custom_routing_function
        self.scoring_func = scoring_func
        self.e_score_correction_bias = e_score_correction_bias
        self.expert_map = None
        self.activation = activation
        self.log2phy = None
        self.global_redundant_expert_num = 0

        is_deepseek_v3_r1 = self.global_num_experts == 256
        self.rm_router_logits = get_rm_router_logits_state(
            self.moe_parallel_config.ep_size, self.dp_size, is_deepseek_v3_r1)
        self.all_reduce_merge = get_all_reduce_merge_state(
            self.moe_parallel_config.ep_size, is_deepseek_v3_r1)

        ascend_config = get_ascend_config()
        self.dynamic_eplb = ascend_config.dynamic_eplb
        self.expert_map_path = ascend_config.expert_map_path
        self.global_redundant_expert_num = ascend_config.init_redundancy_expert
        self.global_num_experts = num_experts + self.global_redundant_expert_num

        # Static EPLB: initialize the expert placement from expert_map_path.
        if self.expert_map_path and os.path.exists(
                self.expert_map_path) and os.access(self.expert_map_path,
                                                    os.R_OK):
            self.expert_load_balancer = ExpertLoadBalancer(
                self.expert_map_path, self.global_num_experts)
            self.local_num_experts, self.expert_map = (
                self.expert_load_balancer.get_rank_placement_map(
                    self.moe_instance_id, self.ep_rank))
            self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(
                self.moe_instance_id, self.ep_rank).npu()
            self.global_redundant_expert_num = (
                self.expert_load_balancer.get_global_redundant_expert_num())
        else:
            # Default MoE initialization.
            self.local_num_experts, self.expert_map = determine_expert_map(
                self.ep_size, self.ep_rank, self.global_num_experts)
            # Dynamic EPLB: initialize the default placement when no
            # expert_map_path is provided.
            if self.dynamic_eplb:
                self.global_redundant_expert_num = ascend_config.init_redundancy_expert
                self.local_num_experts, self.expert_map = determine_default_expert_map(
                    self.global_num_experts, self.ep_size, self.ep_rank,
                    self.global_redundant_expert_num)
                self.log2phy = determine_default_log2phy_map(
                    self.global_num_experts, self.ep_size, self.ep_rank,
                    self.global_redundant_expert_num)

        local_num_experts = (torch.sum(self.expert_map != -1)
                             if self.expert_map is not None else num_experts)
        if self.dynamic_eplb:
            self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)

        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

        if self.scoring_func != "softmax" and not self.use_grouped_topk:
            raise ValueError("Only softmax scoring function is supported for "
                             "non-grouped topk.")

        if vllm_version_is("0.10.2"):
            moe = FusedMoEConfig.make(
                num_experts=self.global_num_experts,
                experts_per_token=top_k,
                hidden_dim=hidden_size,
                num_local_experts=self.local_num_experts,
                moe_parallel_config=self.moe_parallel_config,
                # TODO (bnell): this needs to be fixed for quantized types.
                in_dtype=params_dtype,
                quant_config=quant_config)
        else:
            moe = FusedMoEConfig(
                num_experts=self.global_num_experts,
                experts_per_token=top_k,
                hidden_dim=hidden_size,
                num_local_experts=self.local_num_experts,
                moe_parallel_config=self.moe_parallel_config,
                in_dtype=params_dtype,
            )
        self.moe_config = moe
        # TODO: The self.moe_config.tp_size here is not correct, fix it soon.
        if quant_config is None:
            self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix)
        assert self.quant_method is not None

        local_num_experts = torch.sum(self.expert_map != -1) \
            if self.expert_map is not None else num_experts

        self.moe_load = None
        if self.dynamic_eplb:
            self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)

        moe_quant_params = {
            "num_experts": local_num_experts,
            "hidden_size": hidden_size,
            "intermediate_size_per_partition":
            self.intermediate_size_per_partition,
            "params_dtype": params_dtype,
            "weight_loader": self.weight_loader,
        }
        # need full intermediate size pre-sharding for WNA16 act order
        if (self.quant_method.__class__.__name__
                in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
            moe_quant_params["intermediate_size_full"] = intermediate_size

        self.ep_group = get_ep_group()
        # NOTE: self.tp_group is not expert_tp_group
        self.tp_group = get_tp_group().device_group
        self.quant_method.create_weights(layer=self, **moe_quant_params)

        self.moe_config.tp_group = get_tp_group()
        self.moe_config.dp_group = get_dp_group()
        self.moe_config.ep_group = get_ep_group()
        self.moe_config.mc2_group = get_mc2_group()
        self.moe_config.num_global_redundant_experts = self.global_redundant_expert_num
        setup_moe_comm_method(self.moe_config)

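    # Helpers for (dynamic) EPLB: update or query the expert maps and reset
    # the per-expert load counters at runtime.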
    def update_expert_map(self, new_expert_map):
        self.expert_map = new_expert_map

    def get_map(self):
        return self.expert_map

    def get_log2phy_map(self):
        return self.logical_to_physical_map

    def clear_moe_load(self):
        if self.moe_load is not None:
            self.moe_load.zero_()

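    # forward() runs the full MoE pipeline: the communication method's
    # prepare() step, expert routing and computation via quant_method.apply(),
    # optional expert-load accounting for dynamic EPLB, and finalize().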
    def forward(self,
                hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
                is_prefill: bool,
                enable_force_load_balance: bool = False,
                top_k: Optional[int] = None,
                shared_experts: Optional[Any] = None,
                gate=None,
                replace_allreduce: bool = False):
        assert self.quant_method is not None

        if top_k:
            real_top_k = top_k
        else:
            real_top_k = self.top_k

        forward_context = get_forward_context()
        mc2_mask = forward_context.mc2_mask

        # For w8a8 dynamic quantization, npu_dynamic_quant and the gate can
        # run in parallel.
        quantized_x_for_share, dynamic_scale_for_share = None, None

        if shared_experts:
            # When all_reduce_merge is enabled, shared_experts skips the
            # all_reduce in its MLP and waits until both shared_experts and
            # router_experts have finished before the all_reduce is performed.
            shared_hidden_states = shared_experts(hidden_states)

        if forward_context.sp_enabled:
            replace_allreduce = True

        hidden_states, router_logits = forward_context.moe_comm_method.prepare(
            hidden_states=hidden_states,
            router_logits=router_logits,
            enable_shared_expert_dp=self.enable_shared_expert_dp,
            rm_router_logits=self.rm_router_logits,
            replace_allreduce=replace_allreduce,
            gate=gate)

        # Matrix multiply.
        e_hidden_states = self.quant_method.apply(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=real_top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            global_num_experts=self.global_num_experts,
            expert_map=self.expert_map,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
            is_prefill=is_prefill,
            enable_force_load_balance=enable_force_load_balance,
            log2phy=self.log2phy,
            global_redundant_expert_num=self.global_redundant_expert_num,
            shared_experts=None,
            mc2_mask=mc2_mask,
            quantized_x_for_share=quantized_x_for_share,
            dynamic_scale_for_share=dynamic_scale_for_share,
        )

        group_list_type = None
        if shared_experts:
            if isinstance(e_hidden_states,
                          tuple) and len(e_hidden_states) == 2:
                e_hidden_states, shared_hidden_states = e_hidden_states

        if isinstance(e_hidden_states, tuple) and len(e_hidden_states) == 3:
            e_hidden_states, group_list_type, expert_tokens = e_hidden_states

        if self.dynamic_eplb and group_list_type is not None:
            self.moe_load += expert_tokens if group_list_type else \
                torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])

        final_hidden_states = forward_context.moe_comm_method.finalize(
            hidden_states=e_hidden_states,
            reduce_results=(not self.all_reduce_merge))

        if shared_experts:
            return final_hidden_states, shared_hidden_states
        else:
            return final_hidden_states

    # ----------------------------------------- TBO-related --------------------------------------------
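    # Compute-only MoE step used by the TBO (multi-stream) path: it runs
    # routing and the fused experts kernel via quant_method.apply() without
    # the prepare()/finalize() communication steps performed in forward().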
    def _forward_ms_fused_moe_comp(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_prefill: bool,
        real_top_k,
        enable_force_load_balance: bool = False,
    ):
        hidden_states = self.quant_method.apply(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=real_top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            global_num_experts=self.global_num_experts,
            expert_map=self.expert_map,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
            is_prefill=is_prefill,
            enable_force_load_balance=enable_force_load_balance,
        )

        return hidden_states