### What this PR does / why we need it?
Unify the loading logic for expert_map and log2phy.
1. The expert map generated when redundant experts are enabled is incorrect.
The upstream (community) map-generation function only accepts the number of
global experts. When we pass in the number of logical experts plus the number
of redundant experts, the local expert IDs on the last card index expert IDs
that do not exist. We now ensure that every index points to an existing expert
ID and that every expert can be reached. When redundant experts are not
enabled, the output of our function remains identical to that of the upstream
function.
2. The map we generated was sized by the number of physical experts, but we
only need it sized by the number of logical experts. Since the map is padded
as needed later on, we can simply generate a map whose length equals the
number of logical experts.
3. Unify the initialization logic across different scenarios and
simplify the code for fused_moe.
**Before refactoring**
- map path is not None:
expert map: get_rank_placement_map from _'expert_load_balancer.py'_,
maintains the map for all ranks and all layers.
log2phy: get_rank_log2phy_map from _'expert_load_balancer.py'_,
maintains the map for all ranks and all layers.
- map path is None:
expert map: determine_expert_map from vLLM's fused MoE _'layer.py'_. The
function does not support vllm-ascend's redundant experts.
log2phy: determine_default_log2phy_map from _'eplb_utils.py'_. The
function does not support the redundant experts of vllm-ascend.
**Refactoring**
eplb_utils.py
init_eplb_config
generate placement
generate expert map
generate log2phy
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Expert Mapping Test Generation:
ep size: 16, num of experts: 256, num of redundant experts: 16
+++++++++++++++++++++++++++++++++++++++++
Expert Mapping (entries other than -1 give this rank's local index for that expert)
for Rank 15:
vllm map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 0 1 2 3 4 5 6 7 8
9 10 11 12 13 14 15 16]
+++++++++++++++++++++++++++++++++++++++++
Improved map:
[16 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15]
Expert Mapping Test Generation:
ep size: 16, num of experts: 256, num of redundant experts: 0
+++++++++++++++++++++++++++++++++++++++++
Expert Mapping (entries other than -1 give this rank's local index for that expert)
for Rank 15:
vllm map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15]
+++++++++++++++++++++++++++++++++++++++
Improved map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15]
dsr1 baseline:
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| gsm8k-lite | 7cd45e | accuracy | gen | 100.00 |
dsr1 eplb:
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| gsm8k-lite | 7cd45e | accuracy | gen | 100.00 |
- vLLM version: release/v0.13.0
- vLLM main:
5fbfa8d9ef
Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
466 lines
20 KiB
Python
466 lines
20 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
from typing import Any, Callable, Optional
|
|
|
|
import torch
|
|
from vllm.config import get_current_vllm_config
|
|
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
|
|
tensor_model_parallel_all_reduce)
|
|
from vllm.forward_context import get_forward_context
|
|
from vllm.logger import logger
|
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
|
from vllm.model_executor.layers.fused_moe.layer import (
|
|
FusedMoE, UnquantizedFusedMoEMethod, get_compressed_expert_map)
|
|
from vllm.model_executor.layers.fused_moe.shared_fused_moe import \
|
|
SharedFusedMoE
|
|
|
|
from vllm_ascend.ascend_config import get_ascend_config
|
|
from vllm_ascend.ascend_forward_context import MoECommType
|
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
|
from vllm_ascend.eplb.core.eplb_utils import init_eplb_config
|
|
from vllm_ascend.eplb.utils import moe_load_async_stream
|
|
from vllm_ascend.flash_common3_context import (get_flash_common3_context,
|
|
set_flash_common3_context)
|
|
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
|
from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
|
|
setup_moe_comm_method)
|
|
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
|
|
from vllm_ascend.quantization.w4a8_dynamic import \
|
|
AscendW4A8DynamicFusedMoEMethod
|
|
from vllm_ascend.quantization.w8a8_dynamic import \
|
|
AscendW8A8DynamicFusedMoEMethod
|
|
from vllm_ascend.utils import (AscendDeviceType, enable_sp,
|
|
get_ascend_device_type, maybe_trans_nz,
|
|
npu_stream_switch, shared_expert_dp_enabled,
|
|
shared_experts_calculation_stream)
|
|
|
|
|
|
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
    """Ascend implementation of the unquantized fused-MoE method.

    Tokens are routed with ``select_experts``. The expert computation is
    delegated to the ``moe_comm_method`` stored on the current forward context.
    """

    def __init__(self, moe: Optional[FusedMoEConfig] = None):
        """Create the method.

        Args:
            moe: The fused-MoE configuration forwarded to the upstream
                ``UnquantizedFusedMoEMethod`` constructor.
        """
        super().__init__(moe=moe)
        # Forwarded to ``fused_experts`` on every ``apply`` call.
        self.dynamic_eplb = get_ascend_config().dynamic_eplb

    def process_weights_after_loading(self, layer):
        """Re-layout the loaded expert weights of ``layer`` in place.

        ``super(UnquantizedFusedMoEMethod, self)`` deliberately skips
        ``UnquantizedFusedMoEMethod.process_weights_after_loading`` and calls
        the next implementation in the MRO instead.
        NOTE(review): presumably this avoids the upstream kernel-specific
        weight handling; confirm.
        """
        super(UnquantizedFusedMoEMethod,
              self).process_weights_after_loading(layer)

        # Swap dims 1 and 2 of each (optionally padded) weight and store it
        # contiguously as a frozen parameter.
        w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
            1, 2).contiguous()
        layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)

        w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
            1, 2).contiguous()
        layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)

        # The 310P device type is excluded from the maybe_trans_nz conversion.
        if get_ascend_device_type() != AscendDeviceType._310P:
            layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data)
            layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              use_grouped_topk: bool,
              top_k: int,
              router_logits: torch.Tensor,
              renormalize: bool,
              topk_group: Optional[int] = None,
              num_expert_group: Optional[int] = None,
              custom_routing_function: Optional[Callable] = None,
              scoring_func: str = "softmax",
              routed_scaling_factor: float = 1.0,
              e_score_correction_bias: Optional[torch.Tensor] = None,
              global_num_experts: int = -1,
              expert_map: Optional[torch.Tensor] = None,
              apply_router_weight_on_input: bool = False,
              enable_force_load_balance: bool = False,
              shared_experts: Optional[Any] = None,
              **kwargs) -> torch.Tensor:
        """Route ``x`` to experts and run the fused expert computation.

        Args:
            layer: The MoE layer that owns ``w13_weight`` and ``w2_weight``.
            x: The input hidden states.
            enable_force_load_balance: If true, the routed expert IDs are
                replaced by a random selection drawn from
                ``[0, global_num_experts)``.
            **kwargs: Only ``mc2_mask`` is read; other keys are ignored.

        Returns:
            Whatever ``moe_comm_method.fused_experts`` returns.
        """
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts)

        topk_weights = topk_weights.to(x.dtype)
        # this is a naive implementation for experts load balance so as
        # to avoid accumulating too much tokens on a single rank.
        # currently it is only activated when doing profile runs.
        if enable_force_load_balance:
            # A random permutation per token; the first top-k columns give
            # top-k distinct expert IDs per token.
            random_matrix = torch.rand(topk_ids.size(0),
                                       global_num_experts,
                                       device=topk_ids.device)
            topk_ids = torch.argsort(
                random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)

        moe_comm_method = get_forward_context().moe_comm_method
        return moe_comm_method.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            shared_experts=shared_experts,
            apply_router_weight_on_input=apply_router_weight_on_input,
            dynamic_eplb=self.dynamic_eplb,
            mc2_mask=kwargs.get("mc2_mask", None))
|
|
|
|
|
|
class AscendFusedMoE(FusedMoE):
    """Ascend replacement for vLLM's ``FusedMoE`` layer.

    Expert maps and the logical-to-physical (log2phy) map come from
    ``init_eplb_config``. Token dispatch and combine are delegated to the
    ``moe_comm_method`` on the current forward context.
    """

    # Incremented once per constructed layer; the value becomes the layer's
    # ``moe_instance_id`` and is passed to ``init_eplb_config``.
    moe_counter = -1
    # Shared, lazily-created stream used when ``multistream_overlap_gate`` is on.
    gate_stream: Optional[torch.npu.Stream] = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        num_experts = kwargs["num_experts"]
        intermediate_size = kwargs["intermediate_size"]

        AscendFusedMoE.moe_counter += 1
        self.moe_instance_id = AscendFusedMoE.moe_counter

        self._expert_map = None
        self.log2phy = None
        # Only allocated below when dynamic EPLB is active. Initialising it
        # here keeps ``clear_moe_load`` safe when dynamic EPLB is off.
        self.moe_load = None

        if self.quant_config is None:
            self.quant_method = AscendUnquantizedFusedMoEMethod(
                self.moe_config)
        else:
            self.quant_method = self.quant_config.get_quant_method(
                self, self.layer_name)

        assert self.quant_method is not None

        self.moe_config.tp_group = get_tp_group()
        self.moe_config.dp_group = get_dp_group()
        self.moe_config.ep_group = get_ep_group()
        self.moe_config.mc2_group = get_mc2_group()
        self.moe_config.supports_eplb = self.quant_method.supports_eplb
        ascend_config = get_ascend_config()
        # flashcommon3 gate stream
        self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
        if self.multistream_overlap_gate and AscendFusedMoE.gate_stream is None:
            AscendFusedMoE.gate_stream = torch.npu.Stream()
        if self.custom_routing_function is None and self.e_score_correction_bias is not None:
            # Cast the correction bias to the model dtype.
            vllm_config = get_current_vllm_config()
            self.e_score_correction_bias.data = self.e_score_correction_bias.data.to(
                dtype=vllm_config.model_config.dtype)

        # init moe
        self._expert_map, self.log2phy, self.global_redundant_expert_num = init_eplb_config(
            ascend_config, self.moe_instance_id, self.moe_config)
        # Global expert count includes the redundant experts.
        self.global_num_experts = num_experts + self.global_redundant_expert_num
        # Normalised to a real bool: the bare expression could yield None or
        # an empty string taken from the config.
        self.dynamic_eplb = bool((ascend_config.dynamic_eplb
                                  or ascend_config.expert_map_record_path)
                                 and (self.log2phy is not None))
        # Experts hosted on this rank are the entries of the expert map that
        # are not -1. The count is stored as a Python int (not a 0-d tensor)
        # because it is used as a size, config value and log argument below.
        if self._expert_map is not None:
            self.local_num_experts = int(torch.sum(self._expert_map != -1))
        else:
            self.local_num_experts = self.global_num_experts
        if self._expert_map is not None:
            logger.info_once(
                "[EP Rank %s/%s] Expert parallelism is enabled. Local/global"
                " number of experts: %s/%s. Experts local to global index map:"
                " %s.", self.ep_rank, self.ep_size, self.local_num_experts,
                self.global_num_experts,
                get_compressed_expert_map(self._expert_map))
        if self.dynamic_eplb:
            # Per-local-expert token counters accumulated in ``forward_impl``.
            self.moe_load = torch.zeros(self.local_num_experts,
                                        dtype=torch.int64).npu()

        self.moe_config.num_experts = self.global_num_experts
        self.moe_config.num_local_experts = self.local_num_experts
        self.moe_config.original_num_experts = num_experts

        moe_quant_params = {
            "num_experts": self.local_num_experts,
            "hidden_size": self.hidden_size,
            "intermediate_size_per_partition":
            self.intermediate_size_per_partition,
            "params_dtype": self.params_dtype,
            "weight_loader": self.weight_loader,
        }
        # need full intermediate size pre-sharding for WNA16 act order
        if (self.quant_method.__class__.__name__
                in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
            moe_quant_params["intermediate_size_full"] = intermediate_size
        self.quant_method.create_weights(layer=self, **moe_quant_params)

        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

        setup_moe_comm_method(self.moe_config)
        self.quant_type = self._get_quant_type()

    def _get_quant_type(self) -> QuantType:
        """Map the wrapped quantization method to a ``QuantType``.

        Returns ``QuantType.NONE`` when there is no inner ``quant_method`` or
        when it is neither the W8A8 nor the W4A8 dynamic method.
        """
        quant_method = self.quant_method
        if not hasattr(quant_method,
                       "quant_method") or quant_method.quant_method is None:
            return QuantType.NONE

        method = quant_method.quant_method

        if isinstance(method, AscendW8A8DynamicFusedMoEMethod):
            return QuantType.W8A8
        elif isinstance(method, AscendW4A8DynamicFusedMoEMethod):
            return QuantType.W4A8
        else:
            return QuantType.NONE

    def update_expert_map(self, new_expert_map):
        """Replace the expert map passed to the quant method in later forwards."""
        self._expert_map = new_expert_map

    def get_log2phy_map(self):
        """Return the logical-to-physical expert map (may be ``None``)."""
        return self.log2phy

    def clear_moe_load(self):
        """Zero the per-expert load counters, if dynamic EPLB allocated them."""
        if self.moe_load is not None:
            self.moe_load.zero_()

    def maybe_all_reduce_tensor_model_parallel(
            self, final_hidden_states: torch.Tensor):
        """NOTE(Yizhou): This is to override the parent class method. In `mc2commimpl`,
        and `alltoallcommimpl`, we do not need to all-reduce the final outputs since
        the outputs are already aggregated across tensor parallel ranks in the
        `finalize` function. In `allgathercommimpl`, we still need to all-reduce the
        outputs since each rank only has partial outputs.
        """
        return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
            final_hidden_states)

    def forward_impl(self, hidden_states: torch.Tensor,
                     router_logits: torch.Tensor):
        """Run routing, dispatch, expert computation and combine.

        When ``multistream_overlap_gate`` is enabled, the shared experts and
        ``select_experts`` run on ``AscendFusedMoE.gate_stream`` while the
        current stream runs ``moe_comm_method.prepare``. Their results are
        published through the flash-common3 context.

        Returns:
            The hidden states returned by ``moe_comm_method.finalize``.
        """
        assert self.quant_method is not None

        # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
        # NOTE(review): both values are currently always None here.
        quantized_x_for_share, dynamic_scale_for_share = None, None

        forward_context = get_forward_context()

        # Load balancing for token distribution among experts in dummy_run
        # TODO: The community only considers load balancing when DP > 1.
        # This approach may overlook some extreme scenarios.
        enable_force_load_balance = forward_context.in_profile_run

        if self.multistream_overlap_gate:
            assert AscendFusedMoE.gate_stream is not None
            fc3_context = get_flash_common3_context()
            assert fc3_context is not None
            # The gate stream waits for work already queued on the current stream.
            AscendFusedMoE.gate_stream.wait_stream(torch.npu.current_stream())
            with npu_stream_switch(AscendFusedMoE.gate_stream,
                                   enabled=self.multistream_overlap_gate):
                # share_expert
                assert fc3_context.shared_experts is not None
                shared_out = fc3_context.shared_experts(hidden_states)
                # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
                moe_comm_type = forward_context.moe_comm_type
                if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
                        and not shared_expert_dp_enabled():
                    shared_out = tensor_model_parallel_all_reduce(shared_out)
                set_flash_common3_context(shared_out=shared_out)

                topk_weights, topk_ids = select_experts(
                    hidden_states=hidden_states,
                    router_logits=router_logits,
                    top_k=self.top_k,
                    use_grouped_topk=self.use_grouped_topk,
                    renormalize=self.renormalize,
                    topk_group=self.topk_group,
                    num_expert_group=self.num_expert_group,
                    custom_routing_function=self.custom_routing_function,
                    scoring_func=self.scoring_func,
                    routed_scaling_factor=self.routed_scaling_factor,
                    e_score_correction_bias=self.e_score_correction_bias,
                    global_num_experts=self.global_num_experts)

                if isinstance(forward_context.moe_comm_method,
                              AllGatherCommImpl):
                    topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                        topk_weights, True, True)
                    topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                        topk_ids, True, True)

                set_flash_common3_context(topk_weights=topk_weights,
                                          topk_ids=topk_ids)

        hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
            hidden_states=hidden_states,
            router_logits=router_logits,
            replace_allreduce=forward_context.sp_enabled,
            enable_shared_expert_dp=self.enable_shared_expert_dp,
            quant_type=self.quant_type)

        # Make sure the default stream waits for the gate stream to finish.
        if self.multistream_overlap_gate:
            torch.npu.current_stream().wait_stream(AscendFusedMoE.gate_stream)

        # ``prepare`` may return (hidden_states, per-token scale).
        if isinstance(hidden_states, tuple):
            hidden_states, pertoken_scale = hidden_states
        else:
            pertoken_scale = None

        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            pertoken_scale=pertoken_scale,
            top_k=self.top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            global_num_experts=self.global_num_experts,
            expert_map=self._expert_map,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
            activation=self.activation,
            apply_router_weight_on_input=self.apply_router_weight_on_input,
            quantized_x_for_share=quantized_x_for_share,
            dynamic_scale_for_share=dynamic_scale_for_share,
            shared_experts=None,
            enable_force_load_balance=enable_force_load_balance,
            log2phy=self.log2phy,
            global_redundant_expert_num=self.global_redundant_expert_num,
            mc2_mask=mc2_mask)

        if isinstance(final_hidden_states, tuple):
            final_hidden_states, group_list_type, expert_tokens = final_hidden_states
            if self.dynamic_eplb:
                # Accumulate per-expert token counts on a side stream.
                moe_load_stream = moe_load_async_stream()
                cur_stream = torch.npu.current_stream()

                moe_load_stream.wait_stream(cur_stream)
                with npu_stream_switch(moe_load_stream):
                    # group_list_type == 1: counts are added as-is; otherwise
                    # expert_tokens is treated as a running sum and differenced.
                    self.moe_load += expert_tokens if group_list_type == 1 else \
                        torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
                cur_stream.wait_stream(moe_load_stream)

        final_hidden_states = forward_context.moe_comm_method.finalize(
            hidden_states=final_hidden_states,
            reduce_results=self.reduce_results,
            context_metadata=context_metadata)

        return final_hidden_states
|
|
|
|
|
|
class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
    """``AscendFusedMoE`` combined with a shared-experts module.

    The shared experts run either on ``shared_experts_calculation_stream()``
    (when ``multistream_overlap_shared_expert`` is set) or, when
    ``multistream_overlap_gate`` is set, inside ``AscendFusedMoE.forward_impl``
    via the flash-common3 context. ``forward`` returns
    ``(shared_out, fused_out)``.
    """

    def __init__(
        self,
        shared_experts: torch.nn.Module,
        gate: Optional[torch.nn.Module] = None,
        use_overlapped: bool = True,
        **kwargs,
    ):
        """Build the layer.

        Args:
            shared_experts: Module applied to the input hidden states to
                produce ``shared_out``.
            gate: Optional gate module, exposed through ``gate`` only when
                ``use_overlapped`` is true.
            use_overlapped: Controls whether ``gate`` returns ``_gate``.
            **kwargs: Forwarded unchanged to ``AscendFusedMoE.__init__``.
        """
        # Only AscendFusedMoE.__init__ is invoked explicitly here.
        AscendFusedMoE.__init__(self, **kwargs)

        self._shared_experts = shared_experts
        self.use_overlapped = use_overlapped
        # NOTE(review): not read anywhere in this file; confirm whether it is
        # still needed.
        self.shared_expert_stream = None
        ascend_config = get_ascend_config()
        self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
        # Same value that AscendFusedMoE.__init__ already assigned.
        self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
        if enable_sp():
            logger.info_once(
                "Sequence parallelism is enabled, shared experts are replicated for best performance."
            )

        self._gate = gate

    @property
    def gate(self) -> Optional[torch.nn.Module]:
        """Return the gate module when ``use_overlapped`` is set, else ``None``."""
        return self._gate if self.use_overlapped else None

    @property
    def is_internal_router(self) -> bool:
        """Always ``False`` for this layer.

        NOTE(review): presumably tells vLLM that routing is not performed
        inside this layer; confirm.
        """
        return False

    @property
    def use_dp_chunking(self) -> bool:
        """Always ``False``.

        In vLLM this property selects the chunked forward path that uses the
        FlashInfer Cutlass kernel when data parallelism (DP) is enabled.
        vllm-ascend does not take that path, so it simply returns False.
        """
        return False

    def forward(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(shared_out, fused_out)`` via ``AscendFusedMoE.forward``.

        ``AscendFusedMoE.forward`` is called explicitly rather than through
        ``super()``.
        """
        shared_out, fused_out = AscendFusedMoE.forward(
            self,
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
        return shared_out, fused_out

    def forward_impl(self, hidden_states: torch.Tensor,
                     router_logits: torch.Tensor):
        """Compute the shared experts and the routed experts.

        Without gate overlap, the shared experts run here, optionally on
        ``shared_experts_calculation_stream()``. With gate overlap, they are
        handed to ``AscendFusedMoE.forward_impl`` through the flash-common3
        context, and ``shared_out`` is read back from that context afterwards.

        Returns:
            ``(shared_out, fused_output)``.
        """
        shared_out = None
        if not self.multistream_overlap_gate:
            # Make sure the shared experts stream begins after hidden_states are ready.
            if self.multistream_overlap_shared_expert:
                shared_experts_calculation_stream(
                ).wait_stream(  # type: ignore
                    torch.npu.current_stream())
            with npu_stream_switch(
                    shared_experts_calculation_stream(),
                    enabled=self.multistream_overlap_shared_expert):
                # Use a separate stream to run shared experts.
                shared_out = self._shared_experts(hidden_states)
        else:
            # The base forward_impl runs the shared experts on the gate stream.
            set_flash_common3_context(shared_experts=self._shared_experts)

        fused_output = AscendFusedMoE.forward_impl(
            self,
            hidden_states=hidden_states,
            router_logits=router_logits,
        )

        if not self.multistream_overlap_gate:
            # Make sure the default stream waits for the shared experts stream to finish.
            if self.multistream_overlap_shared_expert:
                torch.npu.current_stream().wait_stream(
                    shared_experts_calculation_stream())

            # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
            forward_context = get_forward_context()
            moe_comm_type = forward_context.moe_comm_type
            if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
                    and not shared_expert_dp_enabled():
                shared_out = tensor_model_parallel_all_reduce(shared_out)
        else:
            # shared_out was stored by AscendFusedMoE.forward_impl.
            fc3_context = get_flash_common3_context()
            assert fc3_context is not None
            shared_out = fc3_context.shared_out

        return shared_out, fused_output
|