2025-02-21 17:07:37 +08:00
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
2025-03-11 21:08:02 +08:00
# Copyright 2023 The vLLM team.
2025-02-21 17:07:37 +08:00
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2025-04-17 14:59:56 +08:00
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/kernels/test_moe.py
2025-02-21 17:07:37 +08:00
2025-06-09 19:28:11 +08:00
import os
2025-08-28 10:13:35 +08:00
from typing import Any , Callable , Optional
2025-02-21 17:07:37 +08:00
import torch
import torch_npu
2025-04-19 17:38:18 +08:00
from vllm . config import get_current_vllm_config
2025-09-24 11:29:59 +08:00
from vllm . distributed import get_tensor_model_parallel_world_size
2025-07-21 09:08:04 +08:00
from vllm . distributed . parallel_state import ( get_dp_group , get_ep_group ,
get_tp_group )
[refactor] Refactoring AscendFusedMoE (#1229)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
This PR is used for resolved [issue
1147](https://github.com/vllm-project/vllm-ascend/issues/1147)
1. Move fused_moe code into one file `fused_moe.py`.
2. Integrate branch conditions into function `get_fused_moe_state`.
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
1. This PR has removed the env `VLLM_ENABLE_MC2`, because I think this
env is useless, we can make judgments based on the current scenario
without this env, it will only increase complexity.
2. This PR has removed the env `USING_LCCL_COM`, because this env has
already expired.
3. `additional_config.expert_tensor_parallel_size` has already expired,
and now we also use parameter `enable_expert_parallel`, consistent with
the vLLM.
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
Signed-off-by: zzzzwwjj <1183291235@qq.com>
2025-06-17 17:49:03 +08:00
from vllm . forward_context import get_forward_context
2025-07-09 08:52:24 +08:00
from vllm . model_executor . layers . fused_moe . config import \
FusedMoEConfig # isort: skip
from vllm . model_executor . layers . fused_moe . config import \
FusedMoEParallelConfig # isort: skip
2025-04-19 17:38:18 +08:00
from vllm . model_executor . layers . fused_moe . layer import (
2025-07-03 18:36:17 +08:00
FusedMoE , UnquantizedFusedMoEMethod , determine_expert_map )
2025-05-28 21:18:41 +08:00
from vllm . model_executor . layers . quantization . base_config import \
QuantizationConfig
2025-04-19 17:38:18 +08:00
2025-06-05 16:28:01 +08:00
from vllm_ascend . ascend_config import get_ascend_config
2025-07-28 14:06:20 +08:00
from vllm_ascend . distributed . parallel_state import get_mc2_group
2025-09-17 10:36:43 +08:00
from vllm_ascend . eplb . core . eplb_utils import ( determine_default_expert_map ,
determine_default_log2phy_map )
2025-06-09 19:28:11 +08:00
from vllm_ascend . ops . expert_load_balancer import ExpertLoadBalancer
2025-09-08 20:09:50 +08:00
from vllm_ascend . ops . moe . experts_selector import select_experts
2025-09-22 19:12:58 +08:00
from vllm_ascend . ops . moe . moe_comm_method import setup_moe_comm_method
2025-09-16 11:06:00 +08:00
from vllm_ascend . utils import ( ACL_FORMAT_FRACTAL_NZ ,
2025-08-28 10:13:35 +08:00
get_all_reduce_merge_state ,
2025-09-20 17:37:57 +08:00
get_rm_router_logits_state , is_310p ,
vllm_version_is )
2025-04-19 17:38:18 +08:00
2025-08-28 10:13:35 +08:00
2025-04-19 17:38:18 +08:00
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
    """Unquantized fused-MoE method specialised for Ascend NPUs.

    Extends vLLM's ``UnquantizedFusedMoEMethod`` with:
      * NPU weight post-processing (optional padding plus FRACTAL_NZ format
        casting on non-310P devices), and
      * an ``apply`` that selects experts via vllm-ascend's
        ``select_experts`` and dispatches the expert computation through the
        MoE communication method stored on the current forward context.
    """

    def __init__(self, moe: Optional[FusedMoEConfig] = None):
        super().__init__(moe=moe)
        vllm_config = get_current_vllm_config()
        self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
        self.max_model_len = vllm_config.model_config.max_model_len
        # NOTE: a second, result-discarding get_ascend_config() call that
        # preceded this line was dead code and has been removed.
        self.dynamic_eplb = get_ascend_config().dynamic_eplb

        try:
            device_group = get_mc2_group().device_group
            # TODO: Try local_rank = ep_group.rank_in_group
            local_rank = torch.distributed.get_rank(group=device_group)
            backend = device_group._get_backend(torch.device("npu"))
            self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
                local_rank)
        except AttributeError:
            # The process-group backend does not expose an HCCL communicator
            # name; leave it unset and let callers handle the None.
            self.moe_all_to_all_group_name = None

    def process_weights_after_loading(self, layer):
        # Deliberately start the MRO lookup *after* UnquantizedFusedMoEMethod
        # so the grandparent hook runs instead of the parent's override.
        super(UnquantizedFusedMoEMethod,
              self).process_weights_after_loading(layer)
        layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
            layer.w13_weight.data),
                                              requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
            layer.w2_weight.data),
                                             requires_grad=False)
        # 310P devices keep the default weight format; every other Ascend
        # device casts the expert weights to FRACTAL_NZ.
        if not is_310p():
            layer.w13_weight.data = torch_npu.npu_format_cast(
                layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
            layer.w2_weight.data = torch_npu.npu_format_cast(
                layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = False,
        enable_force_load_balance: bool = False,
        shared_experts: Optional[Any] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Route ``x`` through the selected experts and return the result.

        Expert selection is delegated to ``select_experts``; the actual
        expert computation (and any required dispatch/combine communication)
        is delegated to the ``moe_comm_method`` on the current forward
        context.  ``is_prefill`` and ``**kwargs`` are accepted for interface
        compatibility but are not used in this implementation.
        """
        topk_weights, topk_ids, row_idx = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts)

        topk_weights = topk_weights.to(x.dtype)
        # This is a naive implementation for experts load balance so as
        # to avoid accumulating too many tokens on a single rank.
        # Currently it is only activated when doing profile runs.
        if enable_force_load_balance and not self.use_aclgraph:
            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)

        moe_comm_method = get_forward_context().moe_comm_method
        return moe_comm_method.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            row_idx=row_idx,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            shared_experts=shared_experts,
            need_trans=True,
            dynamic_eplb=self.dynamic_eplb)
2025-04-19 17:38:18 +08:00
class AscendFusedMoE(FusedMoE):
    """Ascend NPU variant of vLLM's ``FusedMoE`` layer."""

    # The moe_counter parameter is required during the initialization of EPLB
    # to identify the current layer index within the MOE model.  Each
    # instance bumps it in __init__ and records the value as its
    # moe_instance_id.
    moe_counter = -1
2025-05-16 12:14:55 +08:00
def __init__ (
self ,
num_experts : int , # Global number of experts
top_k : int ,
hidden_size : int ,
intermediate_size : int ,
params_dtype : Optional [ torch . dtype ] = None ,
reduce_results : bool = False ,
renormalize : bool = True ,
use_grouped_topk : bool = False ,
num_expert_group : Optional [ int ] = None ,
topk_group : Optional [ int ] = None ,
quant_config : Optional [ QuantizationConfig ] = None ,
tp_size : Optional [ int ] = None ,
ep_size : Optional [ int ] = None ,
dp_size : Optional [ int ] = None ,
prefix : str = " " ,
custom_routing_function : Optional [ Callable ] = None ,
scoring_func : str = " softmax " ,
e_score_correction_bias : Optional [ torch . Tensor ] = None ,
activation : str = " silu " ,
apply_router_weight_on_input : bool = False ,
) :
# TODO: This could not initialize FusedMoE baseclass,
# fixme and make __init__() of AscendFusedMoE more clear
2025-08-08 10:20:23 +08:00
super ( ) . __init__ (
num_experts = num_experts ,
top_k = top_k ,
hidden_size = hidden_size ,
intermediate_size = intermediate_size ,
params_dtype = params_dtype ,
reduce_results = reduce_results ,
renormalize = renormalize ,
use_grouped_topk = use_grouped_topk ,
num_expert_group = num_expert_group ,
topk_group = topk_group ,
quant_config = quant_config ,
tp_size = tp_size ,
ep_size = ep_size ,
dp_size = dp_size ,
prefix = prefix ,
custom_routing_function = custom_routing_function ,
scoring_func = scoring_func ,
e_score_correction_bias = e_score_correction_bias ,
activation = activation ,
2025-08-14 09:17:50 +08:00
apply_router_weight_on_input = apply_router_weight_on_input ,
2025-08-08 10:20:23 +08:00
)
2025-06-09 19:28:11 +08:00
AscendFusedMoE . moe_counter + = 1
self . moe_instance_id = AscendFusedMoE . moe_counter
2025-04-19 17:38:18 +08:00
if params_dtype is None :
params_dtype = torch . get_default_dtype ( )
2025-05-16 12:14:55 +08:00
vllm_config = get_current_vllm_config ( )
2025-07-04 17:54:33 +08:00
self . moe_parallel_config = FusedMoEParallelConfig . make (
tp_size_ = ( tp_size if tp_size is not None else
get_tensor_model_parallel_world_size ( ) ) ,
dp_size_ = ( dp_size
if dp_size is not None else get_dp_group ( ) . world_size ) ,
vllm_parallel_config = vllm_config . parallel_config )
2025-05-28 21:18:41 +08:00
2025-04-19 17:38:18 +08:00
self . top_k = top_k
self . num_experts = num_experts
self . global_num_experts = num_experts
assert intermediate_size % self . tp_size == 0
self . intermediate_size_per_partition = intermediate_size / / self . tp_size
self . reduce_results = reduce_results
self . renormalize = renormalize
self . use_grouped_topk = use_grouped_topk
if self . use_grouped_topk :
assert num_expert_group is not None and topk_group is not None
self . num_expert_group = num_expert_group
self . topk_group = topk_group
self . custom_routing_function = custom_routing_function
self . scoring_func = scoring_func
self . e_score_correction_bias = e_score_correction_bias
self . expert_map = None
self . activation = activation
2025-06-09 19:28:11 +08:00
self . log2phy = None
self . global_redundant_expert_num = 0
2025-04-19 17:38:18 +08:00
2025-07-10 12:07:05 +08:00
is_deepseek_v3_r1 = self . global_num_experts == 256
2025-07-11 08:53:17 +08:00
self . rm_router_logits = get_rm_router_logits_state (
self . moe_parallel_config . ep_size , self . dp_size , is_deepseek_v3_r1 )
2025-07-10 12:07:05 +08:00
self . all_reduce_merge = get_all_reduce_merge_state (
self . moe_parallel_config . ep_size , is_deepseek_v3_r1 )
2025-06-09 19:28:11 +08:00
ascend_config = get_ascend_config ( )
2025-09-17 10:36:43 +08:00
self . dynamic_eplb = ascend_config . dynamic_eplb
self . expert_map_path = ascend_config . expert_map_path
self . global_redundant_expert_num = ascend_config . init_redundancy_expert
self . global_num_experts = num_experts + self . global_redundant_expert_num
# static eplb initializing with expert_map_path
if self . expert_map_path and os . path . exists (
self . expert_map_path ) and os . access ( self . expert_map_path ,
os . R_OK ) :
self . expert_load_balancer = ExpertLoadBalancer (
self . expert_map_path , self . global_num_experts )
self . local_num_experts , self . expert_map = (
self . expert_load_balancer . get_rank_placement_map (
self . moe_instance_id , self . ep_rank ) )
self . log2phy = self . expert_load_balancer . get_rank_log2phy_map (
self . moe_instance_id , self . ep_rank ) . npu ( )
self . global_redundant_expert_num = (
self . expert_load_balancer . get_global_redundant_expert_num ( ) )
2025-06-09 19:28:11 +08:00
else :
2025-09-17 10:36:43 +08:00
# init moe.
2025-06-09 19:28:11 +08:00
self . local_num_experts , self . expert_map = determine_expert_map (
2025-09-17 10:36:43 +08:00
self . ep_size , self . ep_rank , self . global_num_experts )
# dynamic eplb initializing with not expert_map_path
if self . dynamic_eplb :
self . global_redundant_expert_num = ascend_config . init_redundancy_expert
self . local_num_experts , self . expert_map = determine_default_expert_map (
self . global_num_experts , self . ep_size , self . ep_rank ,
self . global_redundant_expert_num )
self . log2phy = determine_default_log2phy_map (
self . global_num_experts , self . ep_size , self . ep_rank ,
self . global_redundant_expert_num )
local_num_experts = ( torch . sum ( self . expert_map != - 1 )
if self . expert_map is not None else num_experts )
if self . dynamic_eplb :
self . moe_load = torch . zeros ( local_num_experts , dtype = torch . int64 )
2025-05-16 12:14:55 +08:00
2025-08-12 14:12:12 +08:00
self . enable_shared_expert_dp = ascend_config . enable_shared_expert_dp
2025-05-16 12:14:55 +08:00
2025-04-19 17:38:18 +08:00
if self . scoring_func != " softmax " and not self . use_grouped_topk :
raise ValueError ( " Only softmax scoring function is supported for "
" non-grouped topk. " )
2025-09-20 17:37:57 +08:00
if vllm_version_is ( " 0.10.2 " ) :
moe = FusedMoEConfig . make (
num_experts = self . global_num_experts ,
experts_per_token = top_k ,
hidden_dim = hidden_size ,
num_local_experts = self . local_num_experts ,
moe_parallel_config = self . moe_parallel_config ,
# TODO (bnell): this needs to be fixed for quantized types.
in_dtype = params_dtype ,
quant_config = quant_config )
else :
moe = FusedMoEConfig (
num_experts = self . global_num_experts ,
experts_per_token = top_k ,
hidden_dim = hidden_size ,
num_local_experts = self . local_num_experts ,
moe_parallel_config = self . moe_parallel_config ,
in_dtype = params_dtype ,
)
2025-08-22 17:09:08 +08:00
self . moe_config = moe
2025-09-29 09:12:49 +08:00
# TODO: The self.moe_config.tp_size here is not correct, fixme soon
2025-08-22 17:09:08 +08:00
2025-05-28 21:18:41 +08:00
if quant_config is None :
2025-08-22 17:09:08 +08:00
self . quant_method = AscendUnquantizedFusedMoEMethod ( moe )
2025-05-28 21:18:41 +08:00
else :
self . quant_method = quant_config . get_quant_method ( self , prefix )
2025-05-16 12:14:55 +08:00
2025-04-19 17:38:18 +08:00
assert self . quant_method is not None
local_num_experts = torch . sum ( self . expert_map != - 1 ) \
if self . expert_map is not None else num_experts
2025-09-17 10:36:43 +08:00
self . moe_load = None
if self . dynamic_eplb :
self . moe_load = torch . zeros ( local_num_experts , dtype = torch . int64 )
2025-04-19 17:38:18 +08:00
moe_quant_params = {
" num_experts " : local_num_experts ,
" hidden_size " : hidden_size ,
" intermediate_size_per_partition " :
self . intermediate_size_per_partition ,
" params_dtype " : params_dtype ,
" weight_loader " : self . weight_loader ,
}
# need full intermediate size pre-sharding for WNA16 act order
if ( self . quant_method . __class__ . __name__
in ( " GPTQMarlinMoEMethod " , " CompressedTensorsWNA16MoEMethod " ) ) :
moe_quant_params [ " intermediate_size_full " ] = intermediate_size
2025-06-04 18:31:41 +08:00
self . ep_group = get_ep_group ( )
[refactor] Refactoring AscendFusedMoE (#1229)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
This PR is used for resolved [issue
1147](https://github.com/vllm-project/vllm-ascend/issues/1147)
1. Move fused_moe code into one file `fused_moe.py`.
2. Integrate branch conditions into function `get_fused_moe_state`.
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
1. This PR has removed the env `VLLM_ENABLE_MC2`, because I think this
env is useless, we can make judgments based on the current scenario
without this env, it will only increase complexity.
2. This PR has removed the env `USING_LCCL_COM`, because this env has
already expired.
3. `additional_config.expert_tensor_parallel_size` has already expired,
and now we also use parameter `enable_expert_parallel`, consistent with
the vLLM.
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
Signed-off-by: zzzzwwjj <1183291235@qq.com>
2025-06-17 17:49:03 +08:00
# NOTE: self.tp_group is not expert_tp_group
self . tp_group = get_tp_group ( ) . device_group
2025-04-19 17:38:18 +08:00
self . quant_method . create_weights ( layer = self , * * moe_quant_params )
2025-09-16 11:06:00 +08:00
self . moe_config . tp_group = get_tp_group ( )
self . moe_config . dp_group = get_dp_group ( )
self . moe_config . ep_group = get_ep_group ( )
self . moe_config . mc2_group = get_mc2_group ( )
self . moe_config . num_global_redundant_experts = self . global_redundant_expert_num
2025-09-22 19:12:58 +08:00
setup_moe_comm_method ( self . moe_config )
2025-08-28 10:13:35 +08:00
2025-09-17 10:36:43 +08:00
    def update_expert_map(self, new_expert_map):
        # Swap in a new expert placement map for this rank (presumably
        # invoked by EPLB when experts are rebalanced — confirm with caller).
        self.expert_map = new_expert_map
    def get_map(self):
        # Return this rank's current expert placement map (may be None when
        # no expert parallelism mapping was built).
        return self.expert_map
    def get_log2phy_map(self):
        # NOTE(review): __init__ stores the logical-to-physical table in
        # self.log2phy, yet this accessor returns
        # self.logical_to_physical_map, which must therefore come from the
        # FusedMoE base class — confirm; otherwise this should likely
        # return self.log2phy.
        return self.logical_to_physical_map
    def clear_moe_load(self):
        # Reset the per-expert token-load counters; self.moe_load is only
        # allocated when dynamic EPLB is enabled, otherwise it stays None.
        if self.moe_load is not None:
            self.moe_load.zero_()
2025-04-19 17:38:18 +08:00
def forward(self,
            hidden_states: torch.Tensor,
            router_logits: torch.Tensor,
            is_prefill: bool,
            enable_force_load_balance: bool = False,
            top_k: Optional[int] = None,
            shared_experts: Optional[Any] = None,
            gate=None,
            replace_allreduce: bool = False):
    """Forward pass of the fused MoE layer.

    Runs the configured MoE communication method's
    ``prepare -> quant_method.apply -> finalize`` pipeline, and optionally
    evaluates a shared-expert module alongside the routed experts.

    Args:
        hidden_states: Input activations for this layer.
        router_logits: Gating logits used for expert selection.
        is_prefill: Whether the current step is a prefill step.
        enable_force_load_balance: Forwarded to the quant method.
        top_k: Optional per-call override of the layer's ``top_k``.
        shared_experts: Optional shared-expert module; when given, its output
            is returned alongside the routed-expert output.
        gate: Optional gate module forwarded to the comm method's ``prepare``.
        replace_allreduce: Forwarded to ``prepare``; forced to True when
            sequence parallelism is enabled in the forward context.

    Returns:
        ``final_hidden_states``, or the tuple
        ``(final_hidden_states, shared_hidden_states)`` when
        ``shared_experts`` is provided.
    """
    assert self.quant_method is not None
    # Allow callers to override the configured top_k per invocation.
    if top_k:
        real_top_k = top_k
    else:
        real_top_k = self.top_k

    forward_context = get_forward_context()
    mc2_mask = forward_context.mc2_mask

    # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
    quantized_x_for_share, dynamic_scale_for_share = None, None

    if shared_experts:
        # When all_reduce_merge is in progress, shared_experts does not do
        # all_reduce in mlp, but waits until shared_experts+router_experts
        # are completed before doing all_reduce.
        shared_hidden_states = shared_experts(hidden_states)

    # Under sequence parallelism the usual all-reduce must be replaced.
    if forward_context.sp_enabled:
        replace_allreduce = True

    hidden_states, router_logits = forward_context.moe_comm_method.prepare(
        hidden_states=hidden_states,
        router_logits=router_logits,
        enable_shared_expert_dp=self.enable_shared_expert_dp,
        rm_router_logits=self.rm_router_logits,
        replace_allreduce=replace_allreduce,
        gate=gate)

    # Matrix multiply.
    e_hidden_states = self.quant_method.apply(
        layer=self,
        x=hidden_states,
        router_logits=router_logits,
        top_k=real_top_k,
        renormalize=self.renormalize,
        use_grouped_topk=self.use_grouped_topk,
        global_num_experts=self.global_num_experts,
        expert_map=self.expert_map,
        topk_group=self.topk_group,
        num_expert_group=self.num_expert_group,
        custom_routing_function=self.custom_routing_function,
        scoring_func=self.scoring_func,
        e_score_correction_bias=self.e_score_correction_bias,
        is_prefill=is_prefill,
        enable_force_load_balance=enable_force_load_balance,
        log2phy=self.log2phy,
        global_redundant_expert_num=self.global_redundant_expert_num,
        shared_experts=None,
        mc2_mask=mc2_mask,
        quantized_x_for_share=quantized_x_for_share,
        dynamic_scale_for_share=dynamic_scale_for_share,
    )

    group_list_type = None
    if shared_experts:
        # A 2-tuple result means the quant method fused the shared-expert
        # computation and returned its output alongside the routed one.
        if isinstance(e_hidden_states,
                      tuple) and len(e_hidden_states) == 2:
            e_hidden_states, shared_hidden_states = e_hidden_states

    if isinstance(e_hidden_states, tuple) and len(e_hidden_states) == 3:
        # 3-tuple additionally carries per-expert token counts for dynamic
        # EPLB load accounting. group_list_type presumably distinguishes a
        # per-expert-count encoding from a cumulative-sum encoding (the
        # else-branch converts cumsum back to per-expert deltas) — TODO
        # confirm against the quant method implementations.
        e_hidden_states, group_list_type, expert_tokens = e_hidden_states
        if self.dynamic_eplb and group_list_type is not None:
            self.moe_load += expert_tokens if group_list_type else \
                torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])

    final_hidden_states = forward_context.moe_comm_method.finalize(
        hidden_states=e_hidden_states,
        reduce_results=(not self.all_reduce_merge))

    if shared_experts:
        return final_hidden_states, shared_hidden_states
    else:
        return final_hidden_states
2025-06-07 16:46:58 +08:00
# ----------------------------------------- TBO-related --------------------------------------------
def _forward_ms_fused_moe_comp (
self ,
hidden_states : torch . Tensor ,
router_logits : torch . Tensor ,
is_prefill : bool ,
real_top_k ,
enable_force_load_balance : bool = False ,
) :
hidden_states = self . quant_method . apply (
layer = self ,
x = hidden_states ,
router_logits = router_logits ,
top_k = real_top_k ,
renormalize = self . renormalize ,
use_grouped_topk = self . use_grouped_topk ,
global_num_experts = self . global_num_experts ,
expert_map = self . expert_map ,
topk_group = self . topk_group ,
num_expert_group = self . num_expert_group ,
custom_routing_function = self . custom_routing_function ,
scoring_func = self . scoring_func ,
e_score_correction_bias = self . e_score_correction_bias ,
is_prefill = is_prefill ,
2025-08-02 09:49:10 +08:00
enable_force_load_balance = enable_force_load_balance ,
)
return hidden_states