[main][Feature]Moe alltoallv communication optimization for unquantized RL training sence (#2088)
It comes from 0.9.1dev
[0.9.1][Feature]Moe alltoallv communication optimization for unquantized
RL training sence & alltoallv support dpo (#1547)
- vLLM version: v0.10.0
- vLLM main:
97608dc276
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: curryliu <120010041@link.cuhk.edu.cn>
Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: ChenTaoyu-SJTU <ctynb@qq.com>
Signed-off-by: taoxudonghaha <justsheldon@163.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Shanshan Shen <87969357+shen-shanshan@users.noreply.github.com>
Signed-off-by: leo-pony <nengjunma@outlook.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: whx <56632993+whx-sjtu@users.noreply.github.com>
Co-authored-by: curryliu <99582471+Irving11-BKN@users.noreply.github.com>
Co-authored-by: Li Wang <wangli858794774@gmail.com>
Co-authored-by: TaoYu Chen <ctynb@qq.com>
Co-authored-by: taoxudonghaha <justsheldon@163.com>
Co-authored-by: Shanshan Shen <467638484@qq.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
62
vllm_ascend/ops/comm_utils.py
Normal file
62
vllm_ascend/ops/comm_utils.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
|
||||
COMM_STREAM = None
|
||||
|
||||
|
||||
def async_all_to_all(input_,
|
||||
output_split_sizes,
|
||||
input_split_sizes,
|
||||
group,
|
||||
event=None):
|
||||
if output_split_sizes is None:
|
||||
# Equal split (all2all)
|
||||
a2a_out = torch.empty_like(input_)
|
||||
else:
|
||||
# Unequal split (all2all-v)
|
||||
a2a_out = input_.new_empty(
|
||||
size=[sum(output_split_sizes)] + list(input_.size()[1:]),
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device(),
|
||||
)
|
||||
|
||||
if event:
|
||||
# multi stream wait event
|
||||
global COMM_STREAM
|
||||
if COMM_STREAM is None:
|
||||
COMM_STREAM = torch_npu.npu.Stream(
|
||||
device=torch.npu.current_device())
|
||||
with torch_npu.npu.stream(COMM_STREAM):
|
||||
event.wait()
|
||||
handle = dist.all_to_all_single(
|
||||
a2a_out,
|
||||
input_.contiguous(),
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=group,
|
||||
async_op=True)
|
||||
else:
|
||||
handle = dist.all_to_all_single(a2a_out,
|
||||
input_.contiguous(),
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=group,
|
||||
async_op=True)
|
||||
return input_, a2a_out, handle
|
||||
@@ -16,12 +16,14 @@
|
||||
# Adapted from vllm/tests/kernels/test_moe.py
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@@ -35,6 +37,7 @@ from vllm.model_executor.layers.fused_moe.config import \
|
||||
FusedMoEParallelConfig # isort: skip
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.quantization.base_config import \
|
||||
QuantizationConfig
|
||||
|
||||
@@ -45,6 +48,8 @@ from vllm_ascend.distributed.communication_op import \
|
||||
data_parallel_reduce_scatter
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
|
||||
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
|
||||
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
|
||||
get_all_reduce_merge_state,
|
||||
@@ -273,11 +278,13 @@ def fused_experts_with_mc2(
|
||||
return hidden_states, shared_hidden_states
|
||||
|
||||
|
||||
def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1) -> torch.Tensor:
|
||||
def apply_mlp(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
apply MLP: gate_up_proj -> swiglu -> down_proj
|
||||
|
||||
@@ -299,9 +306,6 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
|
||||
hidden_states: output hidden states after MLP.
|
||||
"""
|
||||
|
||||
assert len(hidden_states_wrapper) == 1
|
||||
hidden_states = hidden_states_wrapper.pop()
|
||||
|
||||
w1 = w1.transpose(1, 2)
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
@@ -329,6 +333,8 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
|
||||
return hidden_states
|
||||
|
||||
|
||||
# currently expert parallelism implemented with all2all
|
||||
# is under-optimized.
|
||||
def fused_experts_with_all2all(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@@ -543,10 +549,7 @@ def fused_experts_with_all2all_buffer(
|
||||
hidden_states = hidden_states[sorted_idx]
|
||||
group_list_type = 0
|
||||
|
||||
hidden_states_wrapper = [hidden_states]
|
||||
del hidden_states
|
||||
|
||||
hidden_states = apply_mlp(hidden_states_wrapper,
|
||||
hidden_states = apply_mlp(hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
expert_tokens,
|
||||
@@ -682,6 +685,24 @@ def fused_experts_moge(
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def fused_experts_with_all2allv(
|
||||
token_dispatcher,
|
||||
probs,
|
||||
routing_map,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
):
|
||||
# Enable moe alltoallv, it's a balanced policy for precision and efficiency.
|
||||
(share_experts_output, dispatched_input,
|
||||
tokens_per_expert) = (token_dispatcher.token_permutation(
|
||||
hidden_states, probs, routing_map))
|
||||
|
||||
expert_output = apply_mlp(dispatched_input, w1, w2, tokens_per_expert)
|
||||
output, mlp_bias = token_dispatcher.token_unpermutation(expert_output)
|
||||
return output
|
||||
|
||||
|
||||
def fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@@ -1124,6 +1145,16 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
global_batch_size=self.global_batch_size,
|
||||
expert_map=expert_map,
|
||||
ep_group=get_ep_group())
|
||||
elif fused_moe_state == FusedMoEState.All2AllSeq:
|
||||
token_dispatcher = kwargs.get("token_dispatcher")
|
||||
return fused_experts_with_all2allv(
|
||||
token_dispatcher=token_dispatcher,
|
||||
probs=topk_weights,
|
||||
routing_map=topk_ids,
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
)
|
||||
else:
|
||||
return fused_experts_with_all2all(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
@@ -1275,6 +1306,25 @@ class AscendFusedMoE(FusedMoE):
|
||||
# NOTE: self.tp_group is not expert_tp_group
|
||||
self.tp_group = get_tp_group().device_group
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
self.token_dispatcher = None
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ and isinstance(
|
||||
self.quant_method, AscendUnquantizedFusedMoEMethod):
|
||||
self.reduce_results = False
|
||||
moe_dispatcher_config = (
|
||||
MoEDispatcherConfig().set_num_moe_experts(
|
||||
self.global_num_experts).set_num_local_experts(
|
||||
self.local_num_experts).set_moe_router_topk(
|
||||
top_k).set_group_topk(topk_group).
|
||||
set_num_groups(num_expert_group).set_expert_bias(
|
||||
e_score_correction_bias).set_scaling_factor(1.0).build())
|
||||
self.token_dispatcher = MoEAlltoAllSeqOverLapDispatcher(
|
||||
moe_dispatcher_config)
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_DBO:
|
||||
token_dispatcher1 = MoEAlltoAllSeqOverLapDispatcher(
|
||||
moe_dispatcher_config)
|
||||
self.token_dispatchers = [
|
||||
self.token_dispatcher, token_dispatcher1
|
||||
]
|
||||
|
||||
def naive_multicast(self, x: torch.Tensor,
|
||||
cu_tokens_across_dp_cpu: torch.Tensor):
|
||||
@@ -1414,6 +1464,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
shared_experts=shared_experts if self.torchair_graph_enabled
|
||||
and self.enable_multistream_moe and not is_prefill else None,
|
||||
mc2_mask=mc2_mask,
|
||||
token_dispatcher=self.token_dispatcher,
|
||||
quantized_x_for_share=quantized_x_for_share,
|
||||
dynamic_scale_for_share=dynamic_scale_for_share,
|
||||
)
|
||||
@@ -1430,11 +1481,11 @@ class AscendFusedMoE(FusedMoE):
|
||||
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
|
||||
self.tp_group)
|
||||
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
||||
dispose_tensor(e_hidden_states)
|
||||
else:
|
||||
final_hidden_states = e_hidden_states
|
||||
if num_tokens < padding_size:
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
dispose_tensor(e_hidden_states)
|
||||
elif self.dp_size > 1:
|
||||
if fused_moe_state == FusedMoEState.NaiveMulticast:
|
||||
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
@@ -1491,6 +1542,83 @@ class AscendFusedMoE(FusedMoE):
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
is_prefill=is_prefill,
|
||||
enable_force_load_balance=enable_force_load_balance)
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AscendSparseMoeBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
if self.tp_size > config.num_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
f"the number of experts {config.num_experts}.")
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
self.enable_multistream_moe = (
|
||||
ascend_config.torchair_graph_config.enable_multistream_moe)
|
||||
|
||||
self.gate = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.num_experts,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
|
||||
self.experts = AscendFusedMoE(
|
||||
num_experts=config.num_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
reduce_results=False,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
)
|
||||
|
||||
self.top_k = config.num_experts_per_tok
|
||||
|
||||
self.dp_size = get_dp_group().world_size
|
||||
|
||||
self.tp_group = get_tp_group().device_group
|
||||
self.tp_rank = get_tp_group().rank_in_group
|
||||
self.ep_group = get_ep_group()
|
||||
|
||||
self.params_dtype = torch.get_default_dtype()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_metadata is None:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
# when profile runs, force experts to load balanced tokens
|
||||
# to avoid high memory consumption on a single rank.
|
||||
enable_force_load_balance = get_forward_context().in_profile_run
|
||||
is_prefill = get_forward_context().with_prefill
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
is_prefill=is_prefill,
|
||||
top_k=self.top_k,
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
shared_experts=None,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
0
vllm_ascend/ops/moe_dispatcher/__init__.py
Normal file
0
vllm_ascend/ops/moe_dispatcher/__init__.py
Normal file
453
vllm_ascend/ops/moe_dispatcher/token_dispatcher.py
Normal file
453
vllm_ascend/ops/moe_dispatcher/token_dispatcher.py
Normal file
@@ -0,0 +1,453 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
|
||||
from vllm_ascend.distributed.tensor_parallel import (
|
||||
all_gather_last_dim_from_tensor_parallel_region, all_to_all_hp2sp,
|
||||
all_to_all_sp2hp, gather_from_sequence_parallel_region,
|
||||
reduce_scatter_last_dim_to_tensor_parallel_region)
|
||||
from vllm_ascend.ops.comm_utils import async_all_to_all
|
||||
|
||||
|
||||
class MoEDispatcherConfig:
|
||||
|
||||
def __init__(self):
|
||||
self.num_local_experts: int = 0
|
||||
self.num_moe_experts: int = 0
|
||||
self.moe_pad_expert_input_to_capacity: bool = False
|
||||
self.moe_expert_capacity_factor: Optional[float] = None
|
||||
self.moe_router_topk: int = 2
|
||||
self.moe_grouped_gemm: bool = False
|
||||
self.group_topk: int = 0
|
||||
self.num_groups: int = 1
|
||||
self.expert_bias: torch.Tensor = None
|
||||
self.scaling_factor: Optional[float] = None
|
||||
self.is_fused: bool = True
|
||||
|
||||
def set_num_local_experts(self, num_local_experts):
|
||||
self.num_local_experts = num_local_experts
|
||||
return self
|
||||
|
||||
def set_num_moe_experts(self, num_moe_experts):
|
||||
self.num_moe_experts = num_moe_experts
|
||||
return self
|
||||
|
||||
def set_moe_pad_expert_input_to_capacity(self,
|
||||
moe_pad_expert_input_to_capacity):
|
||||
self.moe_pad_expert_input_to_capacity = moe_pad_expert_input_to_capacity
|
||||
return self
|
||||
|
||||
def set_moe_expert_capacity_factor(self, moe_expert_capacity_factor):
|
||||
self.moe_expert_capacity_factor = moe_expert_capacity_factor
|
||||
return self
|
||||
|
||||
def set_moe_router_topk(self, moe_router_topk):
|
||||
self.moe_router_topk = moe_router_topk
|
||||
return self
|
||||
|
||||
def set_moe_grouped_gemm(self, moe_grouped_gemm):
|
||||
self.moe_grouped_gemm = moe_grouped_gemm
|
||||
return self
|
||||
|
||||
def set_group_topk(self, group_topk):
|
||||
self.group_topk = group_topk
|
||||
return self
|
||||
|
||||
def set_num_groups(self, num_groups):
|
||||
self.num_groups = num_groups
|
||||
return self
|
||||
|
||||
def set_expert_bias(self, expert_bias):
|
||||
self.expert_bias = expert_bias
|
||||
return self
|
||||
|
||||
def set_scaling_factor(self, scaling_factor):
|
||||
self.scaling_factor = scaling_factor
|
||||
return self
|
||||
|
||||
def set_is_fused(self, is_fused):
|
||||
self.is_fused = is_fused
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
return self
|
||||
|
||||
|
||||
class MoEDispatcher:
|
||||
|
||||
def __init__(self, config: MoEDispatcherConfig) -> None:
|
||||
"""
|
||||
Initialize the MoE Token Dispatcher.
|
||||
"""
|
||||
self.config = config
|
||||
self.shared_experts = None
|
||||
|
||||
def set_shared_experts(self, shared_experts):
|
||||
self.shared_experts = shared_experts
|
||||
|
||||
@property
|
||||
def ep_group(self):
|
||||
"""Get expert model parallel group."""
|
||||
return get_ep_group().device_group
|
||||
|
||||
@property
|
||||
def ep_rank(self):
|
||||
return get_ep_group().rank_in_group
|
||||
|
||||
@property
|
||||
def ep_size(self):
|
||||
return get_ep_group().world_size
|
||||
|
||||
@property
|
||||
def tp_ep_group(self):
|
||||
"""Get expert tensor and model parallel group."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def tp_ep_size(self):
|
||||
return 1
|
||||
|
||||
|
||||
class MoEAlltoAllSeqOverLapDispatcher(MoEDispatcher):
|
||||
overlap_stream = None
|
||||
"""
|
||||
The implementation of the AlltoAll-based token dispatcher, which handles token
|
||||
dispatching on the sequence level instead of token level. The core of this implementation
|
||||
lies in each device dispatching on the entire sequence, with the hidden state being partitioned.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config: MoEDispatcherConfig):
|
||||
"""
|
||||
Initialize the AlltoAllSeq token dispatcher.
|
||||
|
||||
Args:
|
||||
config (MoEDispatcherConfig): Configuration for the transformer model.
|
||||
"""
|
||||
super().__init__(config)
|
||||
self.num_local_experts = config.num_local_experts
|
||||
self.config = config
|
||||
# use MOEAlltoAllSEQTokenDispatcher to init
|
||||
|
||||
self.hidden_shape = None
|
||||
self.num_input_tokens = None
|
||||
self.num_experts = config.num_moe_experts
|
||||
assert self.num_local_experts > 0, "Expected at least one expert"
|
||||
if self.num_local_experts > 1:
|
||||
self.expert_ids_per_ep_rank = torch.tensor(
|
||||
[i % self.num_local_experts for i in range(self.num_experts)],
|
||||
dtype=torch.int32,
|
||||
device=torch.npu.current_device(),
|
||||
)
|
||||
|
||||
local_expert_indices_offset = (self.ep_rank * self.num_local_experts)
|
||||
|
||||
self.local_expert_indices = [
|
||||
local_expert_indices_offset + i
|
||||
for i in range(self.num_local_experts)
|
||||
]
|
||||
assert (len(self.local_expert_indices) == self.num_local_experts
|
||||
), "Invalid local expert indices"
|
||||
for i in range(len(self.local_expert_indices) - 1):
|
||||
assert (self.local_expert_indices[i] ==
|
||||
self.local_expert_indices[i + 1] -
|
||||
1), "local_expert_indices must be continuous"
|
||||
self.probs = None
|
||||
self.input_splits = None
|
||||
self.output_splits = None
|
||||
self.routing_map = None
|
||||
self.hidden_shape_before_permute = None
|
||||
|
||||
# [tp_ep_size * ep_size, num_local_experts]. Represents the number of tokens sent
|
||||
# to each local expert by all ranks.
|
||||
self.num_global_tokens_per_local_expert_cpu = None
|
||||
self.num_global_tokens_per_local_expert = None
|
||||
|
||||
# A cuda stream synchronization is needed in self.token_permutation()
|
||||
# in some cases, because there are several non-blocking DtoH data
|
||||
# transfers called in self.preprocess(). The synchronization happens
|
||||
# at different points based on MoE settings as late as possible.
|
||||
# Valid sync points are "before_permutation_1", "before_ep_alltoall",
|
||||
# "before_finish", and "no_sync".
|
||||
self.device_sync_point = "no_sync"
|
||||
|
||||
# cached intermediate tensors.
|
||||
self.cached_permutated_local_input_tokens = None
|
||||
self.cached_global_input_tokens = None
|
||||
self.cached_shared_expert_output = None
|
||||
self.tokens_per_expert = None
|
||||
self.perm1_finish_event = None
|
||||
self.global_input_tokens_local_experts_indices = None
|
||||
|
||||
if MoEAlltoAllSeqOverLapDispatcher.overlap_stream is None:
|
||||
MoEAlltoAllSeqOverLapDispatcher.overlap_stream = torch.npu.Stream()
|
||||
|
||||
self.overlap_stream = MoEAlltoAllSeqOverLapDispatcher.overlap_stream
|
||||
|
||||
def preprocess(self,
|
||||
indices: torch.Tensor,
|
||||
with_sync=True) -> torch.Tensor:
|
||||
"""
|
||||
Preprocess routing map for AlltoAll communication and token permutation.
|
||||
This method computes the number of tokens assigned to each expert based on
|
||||
the routing map. It also initializes the necessary data structures for
|
||||
AlltoAll communication, such as input and output splits, and the mapping
|
||||
between global tokens and local experts.
|
||||
|
||||
Args:
|
||||
routing_map (torch.Tensor): The mapping of tokens to experts, with shape
|
||||
[num_tokens, num_experts].
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Tensor containing the number of tokens assigned to local expert.
|
||||
"""
|
||||
num_local_tokens_per_expert = torch.histc(indices,
|
||||
bins=self.num_experts,
|
||||
min=0,
|
||||
max=self.num_experts)
|
||||
|
||||
# num_local_tokens_per_expert: [num_experts]
|
||||
|
||||
ep_size = self.ep_size
|
||||
|
||||
# Dropless
|
||||
self.num_out_tokens = indices.numel()
|
||||
if self.ep_size > 1 or self.num_local_experts > 1:
|
||||
# Token dropless and enable ep. A synchronization is needed before expert parallel
|
||||
# AlltoAll communication to get the `input_splits` and `output_splits` CPU values.
|
||||
self.device_sync_point = "before_ep_alltoall"
|
||||
else:
|
||||
# Token dropless and no ep. A synchronization is needed to get the
|
||||
# `tokens_per_expert` CPU value.
|
||||
self.device_sync_point = "before_finish"
|
||||
|
||||
if ep_size > 1:
|
||||
# ===================================================
|
||||
# Calculate input_splits, output_splits for alltoall-v.
|
||||
# ===================================================
|
||||
self.input_splits = (num_local_tokens_per_expert.reshape(
|
||||
ep_size, self.num_local_experts).sum(axis=1).to(
|
||||
torch.device("cpu"), non_blocking=True).numpy())
|
||||
num_global_tokens_per_expert = gather_from_sequence_parallel_region(
|
||||
num_local_tokens_per_expert,
|
||||
group=self.ep_group).reshape(ep_size, self.num_experts)
|
||||
self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[
|
||||
0]:self.local_expert_indices[-1] + 1]
|
||||
if self.num_global_tokens_per_local_expert is None:
|
||||
raise ValueError(
|
||||
"num_global_tokens_per_local_expert must be set before sum."
|
||||
)
|
||||
self.output_splits = (self.num_global_tokens_per_local_expert.sum(
|
||||
axis=-1).to(torch.device("cpu"), non_blocking=True).numpy())
|
||||
num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(
|
||||
axis=0)
|
||||
# ===================================================
|
||||
# num_global_tokens_per_expert: [ep_size, num_experts]
|
||||
# num_global_tokens_per_local_expert: [ep_size, num_local_experts]
|
||||
# num_tokens_per_local_expert: [num_local_experts]
|
||||
# ===================================================
|
||||
else:
|
||||
self.num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape(
|
||||
-1, self.num_experts)
|
||||
num_tokens_per_local_expert = num_local_tokens_per_expert
|
||||
|
||||
if self.num_local_experts > 1 and with_sync:
|
||||
if self.num_global_tokens_per_local_expert is None:
|
||||
raise ValueError(
|
||||
"num_global_tokens_per_local_expert must be set before operations."
|
||||
)
|
||||
self.device_sync_point = "no_sync"
|
||||
self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
|
||||
self.expert_ids_per_ep_rank,
|
||||
self.num_global_tokens_per_local_expert.ravel())
|
||||
|
||||
return num_tokens_per_local_expert
|
||||
|
||||
def token_permutation(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
probs: torch.Tensor,
|
||||
routing_map: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Dispatch tokens to local experts using AlltoAllSeq communication.
|
||||
|
||||
Args:
|
||||
hidden_states (torch.Tensor): Input token embeddings.
|
||||
probs (torch.Tensor): Probs of tokens assigned to experts.
|
||||
Shape: [num_tokens, num_experts].
|
||||
routing_map (torch.Tensor): Mapping of tokens assigned to experts.
|
||||
Shape: [num_tokens, num_experts].
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]:
|
||||
- Permuted token embeddings for local experts.
|
||||
- Number of tokens per expert.
|
||||
"""
|
||||
self.hidden_shape = hidden_states.shape
|
||||
self.probs = probs
|
||||
self.top_indices = routing_map
|
||||
assert probs.dim() == 2, "Expected 2D tensor for probs"
|
||||
assert routing_map.dim() == 2, "Expected 2D tensor for routing map"
|
||||
|
||||
# Permutation 1: input to AlltoAll input
|
||||
def alltoall_token_permutation1(hidden_states, routing_map):
|
||||
assert self.hidden_shape is not None
|
||||
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
|
||||
tokens_per_expert = self.preprocess(routing_map)
|
||||
if self.tp_ep_size > 1:
|
||||
hidden_states = all_to_all_sp2hp(hidden_states,
|
||||
group=self.tp_ep_group)
|
||||
self.hidden_shape_before_permute = hidden_states.shape
|
||||
|
||||
if self.device_sync_point == "before_permutation_1":
|
||||
torch.npu.current_stream().synchronize()
|
||||
|
||||
permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute(
|
||||
tokens=hidden_states,
|
||||
indices=self.top_indices,
|
||||
num_out_tokens=self.num_out_tokens,
|
||||
)
|
||||
return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert
|
||||
|
||||
permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = alltoall_token_permutation1(
|
||||
hidden_states, routing_map)
|
||||
self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping
|
||||
# permute 1
|
||||
|
||||
ep_group = self.ep_group
|
||||
|
||||
# Perform expert parallel AlltoAll communication
|
||||
if self.device_sync_point == "before_ep_alltoall":
|
||||
torch.npu.current_stream().synchronize()
|
||||
_, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
|
||||
permutated_local_input_tokens,
|
||||
self.output_splits,
|
||||
self.input_splits,
|
||||
ep_group,
|
||||
)
|
||||
|
||||
# shared experts compute
|
||||
if self.shared_experts is not None:
|
||||
(share_experts_output), *_ = self.shared_experts(hidden_states)
|
||||
else:
|
||||
share_experts_output = None
|
||||
|
||||
permute1_ep_all_to_all_handle.wait()
|
||||
permutated_local_input_tokens.untyped_storage().resize_(0)
|
||||
|
||||
def alltoall_token_permutation2(global_input_tokens):
|
||||
# Permutation 2: Sort tokens by local expert.
|
||||
if self.num_local_experts > 1:
|
||||
global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
|
||||
global_input_tokens,
|
||||
self.global_input_tokens_local_experts_indices)
|
||||
|
||||
# Perform tensor parallel AllGather on the hidden dimension to obtain the input tokens.
|
||||
# global_input_tokens: [SEQL, H/TP] -> [SEQL, H]
|
||||
if self.tp_ep_size > 1 and self.config.moe_grouped_gemm:
|
||||
global_input_tokens = all_gather_last_dim_from_tensor_parallel_region(
|
||||
global_input_tokens, self.tp_ep_group)
|
||||
if self.device_sync_point == "before_finish":
|
||||
torch.npu.current_stream().synchronize()
|
||||
|
||||
return global_input_tokens
|
||||
|
||||
# token premute2 input
|
||||
global_input_tokens = alltoall_token_permutation2(global_input_tokens)
|
||||
|
||||
return share_experts_output, global_input_tokens, tokens_per_expert
|
||||
|
||||
def token_unpermutation(self,
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
"""
|
||||
Reverse the token permutation to restore the original order.
|
||||
|
||||
Args:
|
||||
hidden_states (torch.Tensor): Output from local experts.
|
||||
bias (torch.Tensor, optional): Bias tensor (not supported).
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
- Unpermuted token embeddings in the original order.
|
||||
- None (bias is not supported).
|
||||
"""
|
||||
|
||||
def alltoall_token_unpermutation1(hidden_states):
|
||||
assert bias is None, "Bias is not supported in MoEAlltoAllSeqTokenDispatcher"
|
||||
# Perform tensor parallel Reduce-Scatter
|
||||
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
|
||||
if self.tp_ep_size > 1:
|
||||
hidden_states = reduce_scatter_last_dim_to_tensor_parallel_region(
|
||||
hidden_states, group=self.tp_ep_group)
|
||||
|
||||
# Unpermutation 2: expert output to AlltoAll input
|
||||
if hidden_states.shape[0] > 0 and self.num_local_experts > 1:
|
||||
hidden_states = torch_npu.npu_moe_token_unpermute(
|
||||
hidden_states,
|
||||
self.reversed_global_input_permutation_mapping)
|
||||
|
||||
return hidden_states
|
||||
|
||||
hidden_states = alltoall_token_unpermutation1(hidden_states)
|
||||
|
||||
ep_group = self.ep_group
|
||||
# Perform expert parallel AlltoAll communication
|
||||
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
|
||||
_, permutated_local_input_tokens, handle = async_all_to_all(
|
||||
hidden_states, self.input_splits, self.output_splits, ep_group)
|
||||
handle.wait()
|
||||
hidden_states.untyped_storage().resize_(0)
|
||||
|
||||
def alltoall_token_unpermutation2(permutated_local_input_tokens):
|
||||
# Unpermutation 1: AlltoAll output to output
|
||||
|
||||
output = torch_npu.npu_moe_token_unpermute(
|
||||
permuted_tokens=permutated_local_input_tokens,
|
||||
sorted_indices=self.reversed_local_input_permutation_mapping.
|
||||
to(torch.int32),
|
||||
probs=self.probs,
|
||||
restore_shape=self.hidden_shape_before_permute)
|
||||
|
||||
# Perform tensor parallel AlltoAll communication
|
||||
# output: [S*B, H/TP] -> [S*B/TP, H]
|
||||
if self.tp_ep_size > 1:
|
||||
output = all_to_all_hp2sp(output, self.tp_ep_group)
|
||||
|
||||
# Reshape the output tensor
|
||||
output = output.view(self.hidden_shape)
|
||||
return output
|
||||
|
||||
output = alltoall_token_unpermutation2(permutated_local_input_tokens)
|
||||
|
||||
self.input_splits = None
|
||||
self.output_splits = None
|
||||
self.num_global_tokens_per_local_expert = None
|
||||
self.num_global_tokens_per_local_expert_cpu = None
|
||||
|
||||
return output, None
|
||||
Reference in New Issue
Block a user