#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from collections.abc import Callable
from dataclasses import dataclass, field
from functools import wraps

import torch
import torch.nn.functional as F
from vllm.config import get_current_vllm_config
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group, tensor_model_parallel_all_reduce
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, UnquantizedFusedMoEMethod, get_compressed_expert_map
from vllm.model_executor.layers.fused_moe.routed_experts_capturer import RoutedExpertsCapturer
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE

from vllm_ascend.utils import vllm_version_is

if not vllm_version_is("0.16.0"):
    from vllm.model_executor.layers.fused_moe.fused_moe_method_base import FusedMoEMethodBase  # type: ignore
    from vllm.model_executor.layers.fused_moe.router.fused_moe_router import FusedMoERouter  # type: ignore
    from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import DefaultMoERunner  # type: ignore

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.eplb.core.eplb_utils import init_eplb_config
from vllm_ascend.flash_common3_context import get_flash_common3_context, set_flash_common3_context
from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult, setup_moe_comm_method
from vllm_ascend.quantization.methods.base import QuantType
from vllm_ascend.utils import (
    enable_sp,
    maybe_trans_nz,
    npu_stream_switch,
    shared_expert_dp_enabled,
    shared_experts_calculation_stream,
    vllm_version_is,
)
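

# NOTE: These small containers let `AscendFusedMoE.forward_impl` optionally hand back the NPU
# events recorded before dispatch and before combine, so callers that overlap other work with
# the MoE communication phases can synchronize on them (see `return_with_event` below).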
@dataclass
class FusedMoEResult:
    routed_out: torch.Tensor
    before_dispatch_evt: torch.npu.Event | None = None
    before_combine_evt: torch.npu.Event | None = None


@dataclass
class FusedMoEEvents:
    before_routed_experts: torch.npu.Event
    before_dispatch: torch.npu.Event | None = field(default=None)
    before_combine: torch.npu.Event | None = field(default=None)


class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
    def __init__(self, moe: FusedMoEConfig = None):
        super().__init__(moe=moe)
        self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb

    def process_weights_after_loading(self, layer):
        super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)

        w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(1, 2).contiguous()
        layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)

        w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(1, 2).contiguous()
        layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
        layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data)
        layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data)
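
    # apply() performs expert routing via select_experts, optional zero-expert handling and
    # forced load balancing (profile runs only), then delegates the grouped expert computation
    # to the MoE communication method chosen in the current forward context.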
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_force_load_balance: bool = False,
        log2phy: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        zero_expert_num = getattr(layer, "zero_expert_num", 0)
        zero_expert_type = getattr(layer, "zero_expert_type", None)
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts,
        )
        if layer.vllm_config.model_config is not None and layer.vllm_config.model_config.enable_return_routed_experts:
            capturer = RoutedExpertsCapturer.get_instance()
            if capturer is not None:
                capturer.capture(
                    layer_id=layer.layer_id,
                    topk_ids=topk_ids,
                )

        if zero_expert_num > 0 and zero_expert_type is not None:
            topk_ids, topk_weights, zero_expert_result = zero_experts_compute(
                expert_indices=topk_ids,
                expert_scales=topk_weights,
                num_experts=global_num_experts,
                zero_expert_type=zero_expert_type,
                hidden_states=x,
            )

        topk_weights = topk_weights.to(x.dtype)
        # This is a naive implementation of expert load balancing that avoids
        # accumulating too many tokens on a single rank. It is currently only
        # activated during profile runs.
        if enable_force_load_balance:
            random_matrix = torch.rand(topk_ids.size(0), global_num_experts, device=topk_ids.device)
            topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
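
        # The communication method (all-gather, all-to-all, or MC2, chosen per forward context)
        # owns token dispatch, the grouped expert matmuls, and the combine step.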
        moe_comm_method = get_forward_context().moe_comm_method
        final_hidden_states = moe_comm_method.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            w1_bias=layer.w13_bias if self.moe.has_bias else None,
            w2_bias=layer.w2_bias if self.moe.has_bias else None,
            activation=activation,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
            dynamic_eplb=self.dynamic_eplb,
            log2phy=log2phy,
            mc2_mask=kwargs.get("mc2_mask"),
        )
        if zero_expert_num > 0 and zero_expert_type is not None:
            final_hidden_states += zero_expert_result

        return final_hidden_states


if not vllm_version_is("0.16.0"):
    # Please remove this inheritance after extending vllm, todo(wxs)
    class AscendMoERunner(DefaultMoERunner):
        """
        Default implementation of the MoE runner for executing Mixture of Experts layers.

        This class provides a comprehensive implementation for running MoE computations
        with support for:
        - Expert routing and token dispatching
        - Shared experts computation with optional parallel execution using CUDA streams
        - Data parallel (DP) chunking for large batch processing
        - Tensor model parallel and expert parallel operations
        - Various quantization methods and custom operators
        - Both monolithic and decomposed expert execution paths

        The runner handles the complete MoE forward pass including routing tokens to
        experts, executing expert computations, and combining results. It supports
        advanced features like overlapped execution of shared experts and optimized
        kernels for different parallel execution modes.

        Eventually, this class will be split up and specialized for different
        configurations, e.g. the presence or absence of shared experts, a gate, etc.
        """

        def __init__(
            self,
            layer: torch.nn.Module,
            moe_config: FusedMoEConfig,
            router: FusedMoERouter,
            routed_input_transform: torch.nn.Module | None,
            gate: torch.nn.Module | None,
            shared_experts: torch.nn.Module | None,
            quant_method: FusedMoEMethodBase,
            reduce_results: bool,
            enable_dbo: bool,
        ):
            super().__init__(
                layer,
                moe_config,
                router,
                routed_input_transform,
                gate,
                shared_experts,
                quant_method,
                reduce_results,
                enable_dbo,
            )
            if self.shared_experts is None:
                self.moe_forward = torch.ops.vllm.moe_forward
            else:
                self.moe_forward = torch.ops.vllm.moe_forward_shared

        def forward_impl(
            self,
            layer: torch.nn.Module,
            hidden_states: torch.Tensor,
            router_logits: torch.Tensor,
            shared_input: torch.Tensor | None,
        ):
            """
            Override the default forward_impl to use the Ascend-specific implementation.

            This delegates to the layer's forward_impl method, which contains the
            Ascend-specific MoE computation logic.
            """
            result = layer.forward_impl(hidden_states, router_logits)
            # If the layer has shared experts, forward_impl returns a tuple (shared_out, routed_out);
            # otherwise, it returns just routed_out. The torch op expects the matching return type
            # depending on whether it is moe_forward or moe_forward_shared.
            return result
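

# AscendFusedMoE overrides vLLM's FusedMoE to wire in Ascend-specific expert-parallel
# communication (all-gather / all-to-all / MC2), EPLB-based expert placement, and optional
# multi-stream overlap of the gate and shared experts.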
class AscendFusedMoE(FusedMoE):
    moe_counter = -1
    gate_stream: torch.npu.Stream | None = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        num_experts = kwargs["num_experts"]
        intermediate_size = kwargs["intermediate_size"]

        AscendFusedMoE.moe_counter += 1
        self.moe_instance_id = AscendFusedMoE.moe_counter

        self._expert_map = None
        self.log2phy = None

        if self.quant_config is None:
            self.quant_method = AscendUnquantizedFusedMoEMethod(self.moe_config)
        else:
            self.quant_method = self.quant_config.get_quant_method(self, self.layer_name)

        assert self.quant_method is not None

        self.moe_config.tp_group = get_tp_group()
        self.moe_config.dp_group = get_dp_group()
        self.moe_config.ep_group = get_ep_group()
        self.moe_config.mc2_group = get_mc2_group()
        self.moe_config.supports_eplb = self.quant_method.supports_eplb
        ascend_config = get_ascend_config()
        # flashcommon3 gate stream
        self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
        if self.multistream_overlap_gate and AscendFusedMoE.gate_stream is None:
            AscendFusedMoE.gate_stream = torch.npu.Stream()
        if self.custom_routing_function is None and self.e_score_correction_bias is not None:
            vllm_config = get_current_vllm_config()
            self.e_score_correction_bias.data = self.e_score_correction_bias.data.to(
                dtype=vllm_config.model_config.dtype
            )

        # init moe
        eplb_config = ascend_config.eplb_config
        self.global_expert_map, self._expert_map, self.log2phy, self.global_redundant_expert_num = init_eplb_config(
            eplb_config, self.moe_instance_id, self.moe_config
        )
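        # _expert_map maps global expert ids to local slots on this EP rank (-1 when the expert
        # is not hosted here); log2phy maps logical expert ids to physical ids when redundant
        # experts are enabled.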
        self.global_num_experts = num_experts + self.global_redundant_expert_num
        self.dynamic_eplb = eplb_config.dynamic_eplb and (self.log2phy is not None)
        self.local_num_experts = self.global_num_experts // self.ep_size
        if self._expert_map is not None:
            logger.info_once(
                "[EP Rank %s/%s] Expert parallelism is enabled. Local/global"
                " number of experts: %s/%s. Experts local to global index map:"
                " %s.",
                self.ep_rank,
                self.ep_size,
                self.local_num_experts,
                self.global_num_experts,
                get_compressed_expert_map(self._expert_map),
            )
        if self.dynamic_eplb:
            self.moe_load = torch.zeros(self.local_num_experts, dtype=torch.int64).npu()

        self.moe_config.num_experts = self.global_num_experts
        self.moe_config.num_local_experts = self.local_num_experts
        self.moe_config.global_redundant_expert_num = self.global_redundant_expert_num

        moe_quant_params = {
            "num_experts": self.local_num_experts,
            "hidden_size": self.hidden_size,
            "intermediate_size_per_partition": self.intermediate_size_per_partition,
            "params_dtype": self.params_dtype,
            "weight_loader": self.weight_loader,
        }
        # need full intermediate size pre-sharding for WNA16 act order
        if self.quant_method.__class__.__name__ in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod"):
            moe_quant_params["intermediate_size_full"] = intermediate_size
        self.quant_method.create_weights(layer=self, **moe_quant_params)

        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
        self.enable_npugraph_ex_static_kernel = ascend_config.ascend_compilation_config.enable_static_kernel

        setup_moe_comm_method(self.moe_config)
        self.quant_type = self._get_quant_type()
        if not vllm_version_is("0.16.0"):
            self.runner = self._init_runner()
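
    # When the installed vLLM is not 0.16.0, MoE execution goes through the AscendMoERunner
    # constructed here; the version-guarded forward() below then simply delegates to it.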
    if not vllm_version_is("0.16.0"):

        def _init_runner(self):
            # Storing the runner in the FusedMoE is an intermediate state; eventually
            # the runner will own the FusedMoE layer and provide the execution interface
            # for MoE ops.
            return AscendMoERunner(
                layer=self,
                moe_config=self.moe_config,
                router=self.router,
                routed_input_transform=self._routed_input_transform,
                gate=self.gate,
                shared_experts=self.shared_experts,
                quant_method=self.quant_method,
                reduce_results=self.reduce_results,
                enable_dbo=self.vllm_config.parallel_config.enable_dbo,
            )

    def _get_quant_type(self) -> QuantType:
        quant_type = QuantType.NONE
        method = getattr(self.quant_method, "quant_method", None)

        if method is not None:
            quant_type = getattr(method, "quant_type", QuantType.NONE)

        return quant_type

    def update_expert_map(self, new_expert_map):
        self._expert_map = new_expert_map

    def get_log2phy_map(self):
        return self.log2phy

    def clear_moe_load(self):
        if self.moe_load is not None:
            self.moe_load.zero_()

    def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
        """NOTE(Yizhou): This overrides the parent class method. In `mc2commimpl`
        and `alltoallcommimpl`, we do not need to all-reduce the final outputs since
        the outputs are already aggregated across tensor parallel ranks in the
        `finalize` function. In `allgathercommimpl`, we still need to all-reduce the
        outputs since each rank only has partial outputs.
        """
        return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(final_hidden_states)

    if not vllm_version_is("0.16.0"):

        def forward(
            self,
            hidden_states: torch.Tensor,
            router_logits: torch.Tensor,
        ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
            self.ensure_moe_quant_config_init()
            return self.runner.forward(
                hidden_states,
                router_logits,
            )
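
    # forward_impl is the Ascend MoE forward pass proper: optional gate/shared-expert overlap on
    # a side stream, prepare() for the chosen comm method, the (un)quantized expert computation,
    # dynamic EPLB load accounting, and finalize().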
    def forward_impl(  # type: ignore[override]
        self, hidden_states: torch.Tensor, router_logits: torch.Tensor, return_with_event: bool = False
    ) -> torch.Tensor | FusedMoEResult:
        assert self.quant_method is not None

        forward_context = get_forward_context()
        # When static kernels are enabled, the forward pass runs twice (compilation + capture),
        # causing moe_layer_index to overflow. Wrap the index to prevent out-of-bounds errors.
        if self.enable_npugraph_ex_static_kernel:
            forward_context.moe_layer_index = forward_context.moe_layer_index % (len(forward_context.all_moe_layers))

        # Load balancing for token distribution among experts in dummy_run
        # TODO: The community only considers load balancing when DP > 1.
        # This approach may overlook some extreme scenarios.
        enable_force_load_balance = forward_context.in_profile_run

        forward_context = get_forward_context()
        if self.multistream_overlap_gate:
            assert AscendFusedMoE.gate_stream is not None
            fc3_context = get_flash_common3_context()
            assert fc3_context is not None
            AscendFusedMoE.gate_stream.wait_stream(torch.npu.current_stream())
            with npu_stream_switch(AscendFusedMoE.gate_stream, enabled=self.multistream_overlap_gate):
                # share_expert
                assert fc3_context.shared_experts is not None
                shared_out = fc3_context.shared_experts(hidden_states)
                # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
                moe_comm_type = forward_context.moe_comm_type
                if (
                    moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
                    and not shared_expert_dp_enabled()
                ):
                    shared_out = tensor_model_parallel_all_reduce(shared_out)
                set_flash_common3_context(shared_out=shared_out)
                topk_weights, topk_ids = select_experts(
                    hidden_states=hidden_states,
                    router_logits=router_logits,
                    top_k=self.top_k,
                    use_grouped_topk=self.use_grouped_topk,
                    renormalize=self.renormalize,
                    topk_group=self.topk_group,
                    num_expert_group=self.num_expert_group,
                    custom_routing_function=self.custom_routing_function,
                    scoring_func=self.scoring_func,
                    routed_scaling_factor=self.routed_scaling_factor,
                    e_score_correction_bias=self.e_score_correction_bias,
                    global_num_experts=self.global_num_experts,
                )

                if isinstance(forward_context.moe_comm_method, AllGatherCommImpl):
                    topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_weights, True, True)
                    topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_ids, True, True)

                set_flash_common3_context(topk_weights=topk_weights, topk_ids=topk_ids)
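
        # prepare() applies the comm-method-specific pre-processing (padding/gathering and
        # optional quantization); the returned hidden_states may be a (tensor, per-token scale)
        # tuple, which is unpacked below.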
        hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
            hidden_states=hidden_states,
            router_logits=router_logits,
            replace_allreduce=forward_context.flash_comm_v1_enabled,
            enable_shared_expert_dp=self.enable_shared_expert_dp,
            quant_type=self.quant_type,
        )

        # Make sure the default stream waits for the gate stream to finish.
        if self.multistream_overlap_gate:
            torch.npu.current_stream().wait_stream(AscendFusedMoE.gate_stream)

        if isinstance(hidden_states, tuple):
            hidden_states, pertoken_scale = hidden_states
        else:
            pertoken_scale = None

        # Matrix multiply.
        fused_experts_results: FusedExpertsResult = self.quant_method.apply(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            pertoken_scale=pertoken_scale,
            top_k=self.top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            global_num_experts=self.global_num_experts,
            expert_map=self._expert_map,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function,
            scoring_func=self.scoring_func,
            routed_scaling_factor=self.routed_scaling_factor,
            e_score_correction_bias=self.e_score_correction_bias,
            activation=self.activation,
            apply_router_weight_on_input=self.apply_router_weight_on_input,
            enable_force_load_balance=enable_force_load_balance,
            log2phy=self.log2phy,
            global_redundant_expert_num=self.global_redundant_expert_num,
            mc2_mask=mc2_mask,
        )
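        # Dynamic EPLB: accumulate per-expert token counts. When group_list_type == 1 the kernel
        # already reports per-expert counts; otherwise expert_tokens is a cumulative sum and must
        # be differenced first.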
        if self.dynamic_eplb:
            expert_tokens = fused_experts_results.expert_tokens
            group_list_type = fused_experts_results.group_list_type
            assert expert_tokens is not None and group_list_type is not None, (
                "expert_tokens and group_list_type should not be None when dynamic_eplb is enabled."
            )
            local_load = (
                expert_tokens
                if group_list_type == 1
                else torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
            )
            self.moe_load.add_(local_load)
        routed_out = forward_context.moe_comm_method.finalize(
            hidden_states=fused_experts_results.routed_out,
            reduce_results=self.reduce_results,
            context_metadata=context_metadata,
        )

        if return_with_event:
            return FusedMoEResult(
                routed_out=routed_out,
                before_dispatch_evt=fused_experts_results.before_dispatch_evt,
                before_combine_evt=fused_experts_results.before_combine_evt,
            )
        else:
            # The vLLM FusedMoE forward_impl does not return events.
            return routed_out
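

# AscendSharedFusedMoE adds a shared-experts branch on top of AscendFusedMoE and can split the
# shared expert into two halves so they can be overlapped with the routed experts on a separate
# NPU stream.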
2025-09-19 19:05:01 +08:00
|
|
|
class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
2025-09-09 18:19:56 +08:00
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
shared_experts: torch.nn.Module,
|
2026-02-06 15:28:49 +08:00
|
|
|
gate: torch.nn.Module | None = None,
|
2025-09-09 18:19:56 +08:00
|
|
|
use_overlapped: bool = True,
|
2026-02-06 15:28:49 +08:00
|
|
|
routed_input_transform: torch.nn.Module | None = None,
|
2025-09-09 18:19:56 +08:00
|
|
|
**kwargs,
|
|
|
|
|
):
|
2025-09-19 19:05:01 +08:00
|
|
|
AscendFusedMoE.__init__(self, **kwargs)
|
2025-10-25 15:36:32 +08:00
|
|
|
|
2026-02-27 16:05:21 +08:00
|
|
|
        self._routed_input_transform = routed_input_transform
        self._shared_experts = shared_experts
        self.use_overlapped = use_overlapped
        self.shared_expert_stream = None
        ascend_config = get_ascend_config()
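        # Multistream overlap only takes effect when this layer actually owns
        # shared experts.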
        self.multistream_overlap_shared_expert = (
            ascend_config.multistream_overlap_shared_expert and self._shared_experts is not None
        )
        self.multistream_overlap_gate = ascend_config.multistream_overlap_gate and self._shared_experts is not None
        if enable_sp():
            logger.info_once("Sequence parallelism is enabled, shared experts are replicated for best performance.")

        self._gate = gate
        if not vllm_version_is("0.16.0"):
            # Recreate the runner with the correct shared_experts parameter.
            # The parent class created the runner before self._shared_experts was set.
            self.runner = self._init_runner()

        if self.multistream_overlap_shared_expert:
            # Wrap the quant_method's process_weights_after_loading to validate that
            # splitting shared expert computation (gate_up projection, then activation
            # and down projection) yields identical results to the integrated
            # computation after weight loading.
            original_process_weights = self.quant_method.process_weights_after_loading

            @wraps(original_process_weights)
            def wrapped_process_weights(*args, **kwargs):
                result = original_process_weights(*args, **kwargs)
                self._validate_shared_expert_consistency()
                return result

            self.quant_method.process_weights_after_loading = wrapped_process_weights  # type: ignore

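    # The shared-expert MLP is split into two stages so that the up projection can be
    # launched while the routed experts' dispatch communication is in flight, and the
    # remaining work while the combine communication is in flight.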
    def _shared_experts_part1(self, hidden_states: torch.Tensor):
        shared_gate_up, _ = self._shared_experts.gate_up_proj(hidden_states)  # type: ignore
        return shared_gate_up

    def _shared_experts_part2(self, hidden_states: torch.Tensor, shared_gate_up: torch.Tensor):
        shared_act = self._shared_experts.act_fn(shared_gate_up)  # type: ignore
        shared_out, _ = self._shared_experts.down_proj(shared_act)  # type: ignore

        # Qwen3-Next specific gating mechanism
        if hasattr(self._shared_experts, "expert_gate") and self._shared_experts.expert_gate is not None:
            gate_out, _ = self._shared_experts.expert_gate(hidden_states)  # type: ignore
            shared_out = F.sigmoid(gate_out) * shared_out
        return shared_out

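    # Runs once after weight loading (via the wrapped process_weights_after_loading)
    # to check that the two-stage shared-expert path matches the fused module call.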
    def _validate_shared_expert_consistency(self):
        """Validate that the split shared expert computation matches the
        integrated computation."""
        test_input = (
            torch.rand(10, self.hidden_size, device="npu", dtype=self.moe_config.in_dtype) * 2 - 1
        )  # Random test input uniformly sampled from [-1, 1]

        integrated_out = self._shared_experts(test_input)
        part1_out = self._shared_experts_part1(test_input)
        split_out = self._shared_experts_part2(test_input, part1_out)

        if not torch.allclose(integrated_out, split_out):
            diff = (integrated_out - split_out).abs()
            logger.error("SharedFusedMoE shared experts split computation does not match the integrated computation.")
            logger.error(f"Max absolute difference: {diff.max().item()}")
            logger.error(
                "Integrated output - sum: %s, norm: %s", integrated_out.sum().item(), integrated_out.norm().item()
            )
            logger.error("Split output - sum: %s, norm: %s", split_out.sum().item(), split_out.norm().item())
            raise ValueError(
                "SharedFusedMoE shared experts split computation does not match the integrated computation."
            )
        logger.info_once("SharedFusedMoE shared experts split computation matches the integrated computation.")

    @property
    def gate(self) -> torch.nn.Module | None:
        return self._gate if self.use_overlapped else None

    @property
    def is_internal_router(self) -> bool:
        return False

    @property
    def use_dp_chunking(self) -> bool:
        """In upstream vLLM this property routes to the chunked forward path that
        uses the FlashInfer CUTLASS kernel when data parallelism (DP) is enabled.
        vllm-ascend does not use that path, so always return False."""
        return False

    def forward(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if self._shared_experts is None:
            fused_out = AscendFusedMoE.forward(
                self,
                hidden_states=hidden_states,
                router_logits=router_logits,
            )
            shared_out = None
            return shared_out, fused_out

        shared_out, fused_out = AscendFusedMoE.forward(
            self,
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
        return shared_out, fused_out

    def _forward_shared_experts(self, hidden_states: torch.Tensor, fused_moe_evts: FusedMoEEvents):
        if self._shared_experts is None:
            return None

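        # Helper: wait on an optional event recorded by the routed-experts path
        # before launching the next shared-expert stage on this stream.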
        def maybe_wait_event(evt: torch.npu.Event | None):
            if evt is not None:
                torch.npu.current_stream().wait_event(evt)

        with npu_stream_switch(shared_experts_calculation_stream(), enabled=self.multistream_overlap_shared_expert):
            # Ensure the shared experts wait for hidden_states to be ready.
            torch.npu.current_stream().wait_event(fused_moe_evts.before_routed_experts)
            # Execute the gate_up projection concurrently with the dispatch
            # communication.
            maybe_wait_event(fused_moe_evts.before_dispatch)
            part1_out = self._shared_experts_part1(hidden_states)
            # Execute the activation and down projection concurrently with the
            # combine communication.
            maybe_wait_event(fused_moe_evts.before_combine)
            shared_out = self._shared_experts_part2(hidden_states, part1_out)

        # Make sure the default stream waits for the shared experts stream to
        # finish.
        if self.multistream_overlap_shared_expert:
            torch.npu.current_stream().wait_stream(shared_experts_calculation_stream())

        # NOTE: This is exactly the opposite of
        # `maybe_all_reduce_tensor_model_parallel`.
        forward_context = get_forward_context()
        moe_comm_type = forward_context.moe_comm_type
        if (
            moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
            and not shared_expert_dp_enabled()
        ):
            shared_out = tensor_model_parallel_all_reduce(shared_out)
        return shared_out

    def forward_impl(  # type: ignore[override]
        self, hidden_states: torch.Tensor, router_logits: torch.Tensor
    ):
        if self.multistream_overlap_gate:
            set_flash_common3_context(shared_experts=self._shared_experts)

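        # Record an event on the default stream so the shared-expert stream can wait
        # until hidden_states is ready before starting its own computation.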
        before_routed_experts = torch.npu.current_stream().record_event()
        fused_moe_results = AscendFusedMoE.forward_impl(
            self,
            hidden_states=hidden_states,
            router_logits=router_logits,
            return_with_event=True,
        )
        routed_out = fused_moe_results.routed_out

        if self._shared_experts is None:
            return routed_out

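        # With gate overlap enabled, the shared-expert output has already been computed
        # (overlapped with the gate computation) and stashed on the flash_common3
        # context; otherwise run it here, overlapped with dispatch/combine via the
        # recorded events.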
        if self.multistream_overlap_gate:
            fc3_context = get_flash_common3_context()
            assert fc3_context is not None
            shared_out = fc3_context.shared_out
        else:
            shared_out = self._forward_shared_experts(
                hidden_states,
                FusedMoEEvents(
                    before_routed_experts=before_routed_experts,
                    before_dispatch=fused_moe_results.before_dispatch_evt,
                    before_combine=fused_moe_results.before_combine_evt,
                ),
            )

        return shared_out, routed_out