[main] [refactor] refactor fused_moe.py to enable token_dispatchers (#2570)
### What this PR does / why we need it?
Enable token_dispatcher to replace fused_experts_with_xxx in eager mode
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
e2e & ut
- vLLM version: v0.10.1.1
- vLLM main:
704432af3c
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: sherie <963372609@qq.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
Co-authored-by: shiyuan680 <72335504+shiyuan680@users.noreply.github.com>
This commit is contained in:
@@ -16,14 +16,14 @@
|
||||
# Adapted from vllm/tests/kernels/test_moe.py
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
|
||||
@@ -49,9 +49,8 @@ from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
|
||||
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
|
||||
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
|
||||
dispose_tensor, get_all_reduce_merge_state,
|
||||
get_ascend_soc_version,
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, dispose_tensor,
|
||||
get_all_reduce_merge_state,
|
||||
get_rm_router_logits_state, is_310p)
|
||||
|
||||
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
||||
@@ -122,149 +121,6 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
|
||||
return topk_ids_pad, unpad_indices
|
||||
|
||||
|
||||
def fused_experts_with_mc2(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
expert_map: torch.Tensor = None,
|
||||
moe_all_to_all_group_name: Optional[str] = None,
|
||||
shared_experts: Optional[Any] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
quant_mode = 0
|
||||
ep_rank_id = moe_parallel_config.ep_rank
|
||||
ep_world_size = moe_parallel_config.ep_size
|
||||
|
||||
# NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
|
||||
need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3)
|
||||
|
||||
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
|
||||
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
|
||||
|
||||
enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
|
||||
|
||||
moe_expert_num = len(expert_map)
|
||||
kwargs_mc2 = {
|
||||
"x": hidden_states,
|
||||
"expert_ids": topk_ids,
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": moe_expert_num,
|
||||
"global_bs": 0,
|
||||
}
|
||||
|
||||
stage1_kwargs = {
|
||||
"scales": None,
|
||||
"quant_mode": quant_mode,
|
||||
"group_ep": moe_all_to_all_group_name,
|
||||
"ep_world_size": ep_world_size,
|
||||
"ep_rank_id": ep_rank_id,
|
||||
}
|
||||
if need_extra_args:
|
||||
stage1_kwargs.update({
|
||||
"group_tp": moe_all_to_all_group_name,
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
})
|
||||
if a3_need_extra_args and enable_dispatch_v2:
|
||||
stage1_kwargs.update({
|
||||
"x_active_mask": mc2_mask,
|
||||
})
|
||||
|
||||
kwargs_mc2.update(stage1_kwargs)
|
||||
|
||||
output = torch_npu.npu_moe_distribute_dispatch_v2(
|
||||
**kwargs_mc2
|
||||
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
|
||||
**kwargs_mc2)
|
||||
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[
|
||||
0:5]
|
||||
|
||||
if shared_experts is not None:
|
||||
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
|
||||
shared_act = shared_experts.act_fn(shared_gate_up)
|
||||
|
||||
w1 = w1.transpose(1, 2)
|
||||
|
||||
group_list = expert_token_nums.to(torch.int64)
|
||||
gate_up_out_list = torch_npu.npu_grouped_matmul(
|
||||
x=[expand_x],
|
||||
weight=[w1],
|
||||
split_item=2,
|
||||
# 1 means count mode, to avoid cumulative operation of the group list
|
||||
group_list_type=1,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
|
||||
|
||||
w2 = w2.transpose(1, 2)
|
||||
down_out_list = torch_npu.npu_grouped_matmul(
|
||||
x=[gate_up_out],
|
||||
weight=[w2],
|
||||
split_item=2,
|
||||
group_list_type=1,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
# moeCombine
|
||||
kwargs_mc2 = {
|
||||
"expand_x": down_out_list,
|
||||
"expert_ids": topk_ids,
|
||||
"expert_scales": topk_weights.to(torch.float32),
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": moe_expert_num,
|
||||
"global_bs": 0,
|
||||
}
|
||||
tp_recv_counts = output[5]
|
||||
stage3_kwargs = {
|
||||
"ep_send_counts": ep_recv_counts,
|
||||
"group_ep": moe_all_to_all_group_name,
|
||||
"ep_world_size": ep_world_size,
|
||||
"ep_rank_id": ep_rank_id,
|
||||
}
|
||||
if enable_dispatch_v2:
|
||||
stage3_kwargs.update({
|
||||
"assist_info_for_combine":
|
||||
assist_info_for_combine,
|
||||
})
|
||||
else:
|
||||
stage3_kwargs.update({
|
||||
"expand_idx": assist_info_for_combine,
|
||||
})
|
||||
if need_extra_args:
|
||||
stage3_kwargs.update({
|
||||
"tp_send_counts": tp_recv_counts,
|
||||
"group_tp": moe_all_to_all_group_name,
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
})
|
||||
if a3_need_extra_args and enable_dispatch_v2:
|
||||
stage3_kwargs.update({
|
||||
"x_active_mask": mc2_mask,
|
||||
})
|
||||
kwargs_mc2.update(stage3_kwargs)
|
||||
|
||||
hidden_states = torch_npu.npu_moe_distribute_combine_v2(
|
||||
**kwargs_mc2
|
||||
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
|
||||
**kwargs_mc2)
|
||||
|
||||
if shared_experts is None:
|
||||
return hidden_states
|
||||
else:
|
||||
shared_hidden_states, _ = shared_experts.down_proj(shared_act)
|
||||
return hidden_states, shared_hidden_states
|
||||
|
||||
|
||||
def apply_mlp(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@@ -318,248 +174,6 @@ def apply_mlp(
|
||||
return hidden_states
|
||||
|
||||
|
||||
# currently expert parallelism implemented with all2all
|
||||
# is under-optimized.
|
||||
def fused_experts_with_all2all(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
top_k: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
ep_group: GroupCoordinator = None,
|
||||
):
|
||||
original_shape = hidden_states.shape
|
||||
if len(original_shape) == 3:
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
|
||||
num_tokens, _ = hidden_states.shape
|
||||
num_experts = w1.shape[0]
|
||||
|
||||
if expert_map is not None:
|
||||
global_num_experts = len(expert_map)
|
||||
local_num_experts = global_num_experts // ep_group.world_size
|
||||
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
expert_idx=topk_ids,
|
||||
active_num=num_tokens)
|
||||
|
||||
global_expert_tokens = torch.bincount(expanded_expert_idx,
|
||||
minlength=global_num_experts)
|
||||
scatter_sizes = global_expert_tokens.view(ep_group.world_size,
|
||||
-1).sum(-1)
|
||||
|
||||
gather_sizes = torch.empty_like(scatter_sizes)
|
||||
dist.all_to_all_single(gather_sizes,
|
||||
scatter_sizes,
|
||||
group=ep_group.device_group)
|
||||
scatter_size_list = scatter_sizes.cpu().tolist()
|
||||
gather_size_list = gather_sizes.cpu().tolist()
|
||||
|
||||
expanded_expert_idx = expanded_expert_idx % local_num_experts
|
||||
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
|
||||
scatter_size_list,
|
||||
gather_size_list)
|
||||
local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
|
||||
scatter_size_list,
|
||||
gather_size_list)
|
||||
|
||||
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
|
||||
|
||||
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
||||
sorted_local_expert_idx, local_num_experts).to(torch.int64)
|
||||
|
||||
hidden_states = hidden_states[sorted_idx]
|
||||
else:
|
||||
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
expert_idx=topk_ids,
|
||||
active_num=num_tokens)
|
||||
|
||||
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
||||
expanded_expert_idx, num_experts)
|
||||
expert_tokens = expert_tokens.to(torch.int64)
|
||||
|
||||
w1 = w1.transpose(1, 2)
|
||||
gate_up_out_list = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=expert_tokens,
|
||||
)[0]
|
||||
|
||||
hidden_states = torch_npu.npu_swiglu(gate_up_out_list)
|
||||
|
||||
w2 = w2.transpose(1, 2)
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=expert_tokens,
|
||||
)[0]
|
||||
|
||||
if expert_map is not None:
|
||||
resorted_idx = torch.argsort(sorted_idx)
|
||||
hidden_states = hidden_states[resorted_idx]
|
||||
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
|
||||
gather_size_list,
|
||||
scatter_size_list)
|
||||
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
hidden_states,
|
||||
skip1=None,
|
||||
skip2=None,
|
||||
bias=None,
|
||||
scales=topk_weights,
|
||||
expanded_src_to_dst_row=expanded_row_idx,
|
||||
export_for_source_row=topk_ids,
|
||||
)
|
||||
else:
|
||||
# TODO: Reorder device memory 2 times here, replace the current
|
||||
# implementation here when suitable operators become available.
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
hidden_states,
|
||||
skip1=None,
|
||||
skip2=None,
|
||||
bias=None,
|
||||
scales=topk_weights,
|
||||
expanded_src_to_dst_row=expanded_row_idx,
|
||||
export_for_source_row=topk_ids,
|
||||
)
|
||||
if len(original_shape) == 3:
|
||||
final_hidden_states = final_hidden_states.view(original_shape)
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
# currently expert parallelism implemented with all2all
|
||||
# is under-optimized.
|
||||
def fused_experts_with_all2all_buffer(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
top_k: int,
|
||||
max_model_len: int,
|
||||
global_batch_size: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
ep_group: GroupCoordinator = None,
|
||||
):
|
||||
original_shape = hidden_states.shape
|
||||
if len(original_shape) == 3:
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
|
||||
num_tokens, _ = hidden_states.shape
|
||||
|
||||
global_num_experts = len(expert_map)
|
||||
local_num_experts = global_num_experts // ep_group.world_size
|
||||
row_idx_len = num_tokens * top_k
|
||||
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
expert_idx=topk_ids,
|
||||
active_num=num_tokens)
|
||||
|
||||
max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) *
|
||||
max_model_len // ep_group.world_size +
|
||||
1) * top_k * 2
|
||||
expert_idx_buffer_scatter, unpad_indices = process_topk_ids(
|
||||
expanded_expert_idx, global_num_experts, ep_group.world_size,
|
||||
max_row_per_ep_rank, num_tokens, top_k)
|
||||
hidden_states_pad_idx = torch.zeros(
|
||||
expert_idx_buffer_scatter.shape,
|
||||
dtype=expert_idx_buffer_scatter.dtype,
|
||||
device=expert_idx_buffer_scatter.device)
|
||||
non_pad_len = torch.sum((expert_idx_buffer_scatter
|
||||
!= global_num_experts).to(torch.int32))
|
||||
hidden_states_pad_idx[expert_idx_buffer_scatter !=
|
||||
global_num_experts] = torch.arange(
|
||||
non_pad_len,
|
||||
dtype=expert_idx_buffer_scatter.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx]
|
||||
expert_idx_buffer_gather = torch.empty_like(
|
||||
expert_idx_buffer_scatter,
|
||||
dtype=expert_idx_buffer_scatter.dtype,
|
||||
device=expert_idx_buffer_scatter.device)
|
||||
hidden_states_buffer_gather = torch.empty_like(
|
||||
hidden_states_buffer_scatter,
|
||||
dtype=hidden_states_buffer_scatter.dtype,
|
||||
device=hidden_states_buffer_scatter.device)
|
||||
dist.all_to_all_single(expert_idx_buffer_gather,
|
||||
expert_idx_buffer_scatter,
|
||||
group=ep_group.device_group)
|
||||
dist.all_to_all_single(hidden_states_buffer_gather,
|
||||
hidden_states_buffer_scatter,
|
||||
group=ep_group.device_group)
|
||||
mask = expert_idx_buffer_gather != global_num_experts
|
||||
local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * (
|
||||
global_num_experts // ep_group.world_size)
|
||||
hidden_states = hidden_states_buffer_gather[mask]
|
||||
idx_type = local_expert_idx.dtype
|
||||
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float())
|
||||
sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type)
|
||||
|
||||
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
||||
sorted_local_expert_idx, local_num_experts).to(torch.int64)
|
||||
hidden_states = hidden_states[sorted_idx]
|
||||
group_list_type = 0
|
||||
|
||||
hidden_states = apply_mlp(hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
expert_tokens,
|
||||
group_list_type=group_list_type)
|
||||
|
||||
resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype)
|
||||
hidden_states = hidden_states[resorted_idx]
|
||||
hidden_states_scatter = torch.zeros(
|
||||
(mask.shape[0], hidden_states.shape[1]),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
hidden_states_scatter[mask] = hidden_states
|
||||
hidden_states_gatter = torch.empty_like(
|
||||
hidden_states_scatter,
|
||||
dtype=hidden_states_scatter.dtype,
|
||||
device=hidden_states_scatter.device)
|
||||
dist.all_to_all_single(hidden_states_gatter,
|
||||
hidden_states_scatter,
|
||||
group=ep_group.device_group)
|
||||
hidden_states_gatter = hidden_states_gatter[expert_idx_buffer_scatter !=
|
||||
global_num_experts]
|
||||
if hidden_states_gatter.shape[0] != row_idx_len:
|
||||
hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
hidden_states[unpad_indices != -1] = hidden_states_gatter
|
||||
else:
|
||||
# TODO: Reorder device memory 2 times here, replace the current
|
||||
hidden_states = hidden_states_gatter
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
hidden_states,
|
||||
skip1=None,
|
||||
skip2=None,
|
||||
bias=None,
|
||||
scales=topk_weights,
|
||||
expanded_src_to_dst_row=expanded_row_idx,
|
||||
export_for_source_row=topk_ids,
|
||||
)
|
||||
|
||||
if len(original_shape) == 3:
|
||||
final_hidden_states = final_hidden_states.view(original_shape)
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def fused_experts_moge(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@@ -651,188 +265,228 @@ def fused_experts_moge(
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def fused_experts_with_all2allv(
|
||||
token_dispatcher,
|
||||
probs,
|
||||
routing_map,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
):
|
||||
# Enable moe alltoallv, it's a balanced policy for precision and efficiency.
|
||||
(share_experts_output, dispatched_input,
|
||||
tokens_per_expert) = (token_dispatcher.token_permutation(
|
||||
hidden_states, probs, routing_map))
|
||||
|
||||
expert_output = apply_mlp(dispatched_input, w1, w2, tokens_per_expert)
|
||||
output, mlp_bias = token_dispatcher.token_unpermutation(expert_output)
|
||||
return output
|
||||
|
||||
|
||||
def fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
top_k: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
max_num_tokens: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Fused experts with top-k routing.
|
||||
|
||||
Args:
|
||||
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
||||
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
|
||||
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
|
||||
topk_weights: Routing weights of shape (num_tokens, top_k).
|
||||
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
||||
top_k: Number of experts to select.
|
||||
expert_map: Expert mapping of shape (num_experts,).
|
||||
|
||||
Returns:
|
||||
hidden_states: Hidden states after routing.
|
||||
"""
|
||||
"""
|
||||
# Check constraints.
|
||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
||||
"""
|
||||
# if torch.distributed.get_rank() == 0:
|
||||
# print(w1.shape)
|
||||
# print(hidden_states.shape)
|
||||
|
||||
original_shape = hidden_states.shape
|
||||
# assert len(original_shape) == 2
|
||||
|
||||
num_tokens = hidden_states.shape[:-1].numel()
|
||||
num_experts = w1.shape[0]
|
||||
dtype = hidden_states.dtype
|
||||
device = hidden_states.device
|
||||
# assert dtype in [torch.float32, torch.float16, torch.bfloat16
|
||||
# ], "Only float32, float16, and bfloat16 are supported"
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
assert (topk_weights.dim() == 2
|
||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
_, topk = topk_weights.shape
|
||||
assert (
|
||||
topk == 1
|
||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
|
||||
|
||||
if expert_map is not None:
|
||||
# Generate token indices and flatten
|
||||
token_indices = (torch.arange(num_tokens,
|
||||
device=device,
|
||||
dtype=torch.int64).unsqueeze(1).expand(
|
||||
-1, top_k).reshape(-1))
|
||||
|
||||
# Flatten token-to-expert mappings and map to local experts
|
||||
weights_flat = topk_weights.view(-1)
|
||||
experts_flat = topk_ids.view(-1)
|
||||
local_experts_flat = expert_map[experts_flat]
|
||||
|
||||
# Filter valid token-expert pairs
|
||||
mask = local_experts_flat != -1
|
||||
filtered_weights = torch.where(
|
||||
mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype)
|
||||
filtered_experts = torch.where(
|
||||
mask, local_experts_flat,
|
||||
torch.full_like(local_experts_flat,
|
||||
num_experts)).to(topk_ids.dtype)
|
||||
|
||||
# Sort by local expert IDs
|
||||
sort_indices = torch.argsort(filtered_experts.view(torch.float32))
|
||||
sorted_token_indices = token_indices[sort_indices]
|
||||
sorted_weights = filtered_weights[sort_indices]
|
||||
|
||||
# Compute token counts with minlength of num_experts
|
||||
# This is equivalent to but faster than:
|
||||
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
|
||||
token_counts = torch.zeros(num_experts + 1,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
|
||||
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
|
||||
token_counts = token_counts[:num_experts]
|
||||
expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64)
|
||||
|
||||
# Rearrange hidden_states
|
||||
sorted_hidden_states = hidden_states[sorted_token_indices]
|
||||
def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
group_list_type: int = 1,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None) -> torch.Tensor:
|
||||
if dynamic_scale is None:
|
||||
unquantized_hidden_states = hidden_states
|
||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
# Dispose the original unquantized hidden states
|
||||
# to save npu memory because they're no longer used.
|
||||
dispose_tensor(unquantized_hidden_states)
|
||||
else:
|
||||
active_num = max_num_tokens if max_num_tokens is not None else num_tokens
|
||||
sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
expert_idx=topk_ids,
|
||||
active_num=active_num)
|
||||
pertoken_scale = dynamic_scale
|
||||
|
||||
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
||||
expanded_expert_idx, num_experts)
|
||||
expert_tokens = expert_tokens.to(torch.int64)
|
||||
bias1, bias2 = None, None
|
||||
_output_dtype = w2_scale.dtype
|
||||
|
||||
is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2
|
||||
if w1_scale_bias is None and is_mc2:
|
||||
w1_scale = w1_scale.to(torch.float32)
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=3,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=torch.int32)[0]
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight_scale=w1_scale,
|
||||
activation_scale=pertoken_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=group_list,
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
scale=[w2_scale],
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=w2_scale.dtype)[0]
|
||||
else:
|
||||
if w1_scale_bias is not None:
|
||||
if group_list_type == 0:
|
||||
group_list = torch.cat(
|
||||
[group_list[:1],
|
||||
torch.diff(group_list, dim=0)])
|
||||
group_list_type = 1
|
||||
bias1 = [w1_scale_bias]
|
||||
bias2 = [w2_scale_bias]
|
||||
# TODO w4a8 scene: dynamic acquisition of dtype in the future
|
||||
_output_dtype = torch.bfloat16
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
scale=[w1_scale],
|
||||
bias=bias1,
|
||||
per_token_scale=[pertoken_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
scale=[w2_scale],
|
||||
bias=bias2,
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
return hidden_states
|
||||
|
||||
|
||||
def unquant_apply_mlp(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1,
|
||||
topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
w1 = w1.transpose(1, 2)
|
||||
gate_up_out_list = torch_npu.npu_grouped_matmul(
|
||||
x=[sorted_hidden_states],
|
||||
gate_up_out = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=expert_tokens,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
if is_310p():
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
||||
torch.float16)
|
||||
else:
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
||||
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
|
||||
if topk_scales is not None:
|
||||
gate_up_out *= topk_scales
|
||||
|
||||
w2 = w2.transpose(1, 2)
|
||||
down_out_list = torch_npu.npu_grouped_matmul(
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[gate_up_out],
|
||||
weight=[w2],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=expert_tokens,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
return hidden_states
|
||||
|
||||
if expert_map is not None:
|
||||
weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)
|
||||
|
||||
final_hidden_states = torch.zeros(*original_shape,
|
||||
device=hidden_states.device,
|
||||
dtype=dtype)
|
||||
|
||||
# TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
|
||||
# This created multiple NaN and index_add_ will mix them up which harms accuracy
|
||||
# remove this mask and filter after it being fixed
|
||||
num_valid_tokens = mask.sum()
|
||||
valid_token_mask = torch.arange(
|
||||
0, sorted_token_indices.shape[0],
|
||||
device=device).unsqueeze(1) < num_valid_tokens
|
||||
valid_output = torch.where(
|
||||
valid_token_mask, weighted_down_out,
|
||||
torch.zeros_like(weighted_down_out)).to(dtype)
|
||||
final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
|
||||
def unified_apply_mlp(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
group_list_type: int = 1,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if get_forward_context().with_quant:
|
||||
return quant_apply_mlp(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
group_list=group_list,
|
||||
dynamic_scale=dynamic_scale,
|
||||
group_list_type=group_list_type,
|
||||
w1_scale_bias=w1_scale_bias,
|
||||
w2_scale_bias=w2_scale_bias)
|
||||
else:
|
||||
scales = torch.ones_like(
|
||||
topk_weights) if apply_router_weight_on_input else topk_weights
|
||||
# TODO: Reorder device memory 2 times here, replace the current
|
||||
# implementation here when suitable operators become available.
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
down_out_list,
|
||||
skip1=None,
|
||||
skip2=None,
|
||||
bias=None,
|
||||
scales=scales,
|
||||
expanded_src_to_dst_row=expanded_row_idx,
|
||||
export_for_source_row=topk_ids,
|
||||
)
|
||||
return unquant_apply_mlp(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
group_list=group_list,
|
||||
group_list_type=group_list_type,
|
||||
topk_scales=topk_scales)
|
||||
|
||||
|
||||
def unified_fused_experts_eager(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w1_scale_bias: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale_bias: Optional[torch.Tensor] = None,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
shared_gate_up: Optional[Any] = None,
|
||||
shared_dequant_scale: Optional[Any] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False):
|
||||
token_dispatcher = get_forward_context().token_dispatcher
|
||||
|
||||
results = token_dispatcher.token_dispatch(
|
||||
hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
expert_map=expert_map,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
shared_experts=shared_experts,
|
||||
shared_gate_up=shared_gate_up,
|
||||
shared_dequant_scale=shared_dequant_scale,
|
||||
mc2_mask=mc2_mask,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
expert_output = unified_apply_mlp(
|
||||
hidden_states=results["hidden_states"],
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
group_list=results["group_list"],
|
||||
dynamic_scale=results.get("dynamic_scale"),
|
||||
group_list_type=results.get("group_list_type"),
|
||||
w1_scale_bias=w1_scale_bias,
|
||||
w2_scale_bias=w2_scale_bias,
|
||||
topk_scales=results.get("topk_scales"))
|
||||
final_hidden_states = token_dispatcher.token_combine(expert_output)
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
@@ -914,65 +568,16 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
if enable_force_load_balance and not self.use_aclgraph:
|
||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||
|
||||
fused_moe_state = get_forward_context().fused_moe_state
|
||||
|
||||
if fused_moe_state == FusedMoEState.MC2:
|
||||
return fused_experts_with_mc2(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
moe_parallel_config=self.moe.moe_parallel_config,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
||||
shared_experts=shared_experts,
|
||||
mc2_mask=kwargs.get("mc2_mask", None))
|
||||
elif fused_moe_state in [
|
||||
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
|
||||
]:
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map)
|
||||
elif MOE_ALL2ALL_BUFFER:
|
||||
return fused_experts_with_all2all_buffer(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
top_k=top_k,
|
||||
max_model_len=self.max_model_len,
|
||||
global_batch_size=self.global_batch_size,
|
||||
expert_map=expert_map,
|
||||
ep_group=get_ep_group())
|
||||
elif fused_moe_state == FusedMoEState.All2AllSeq:
|
||||
token_dispatcher = kwargs.get("token_dispatcher")
|
||||
return fused_experts_with_all2allv(
|
||||
token_dispatcher=token_dispatcher,
|
||||
probs=topk_weights,
|
||||
routing_map=topk_ids,
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
)
|
||||
else:
|
||||
return fused_experts_with_all2all(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
ep_group=get_ep_group())
|
||||
return unified_fused_experts_eager(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
expert_map=expert_map,
|
||||
shared_experts=shared_experts,
|
||||
mc2_mask=kwargs.get(
|
||||
"mc2_mask", None))
|
||||
|
||||
|
||||
class AscendFusedMoE(FusedMoE):
|
||||
@@ -1154,6 +759,19 @@ class AscendFusedMoE(FusedMoE):
|
||||
self.token_dispatcher, token_dispatcher1
|
||||
]
|
||||
|
||||
ep_size = (get_ep_group().world_size if
|
||||
vllm_config.parallel_config.enable_expert_parallel else 1)
|
||||
with_quant = quant_config is not None
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
|
||||
setup_token_dispatchers
|
||||
setup_token_dispatchers(
|
||||
ep_size,
|
||||
top_k=self.top_k,
|
||||
num_experts=self.global_num_experts,
|
||||
num_global_redundant_experts=self.global_redundant_expert_num,
|
||||
num_local_experts=self.local_num_experts,
|
||||
with_quant=with_quant)
|
||||
|
||||
def naive_multicast(self, x: torch.Tensor,
|
||||
cu_tokens_across_dp_cpu: torch.Tensor):
|
||||
assert (len(x.shape) == 2)
|
||||
|
||||
@@ -22,21 +22,18 @@
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.distributed.tensor_parallel import (
|
||||
all_gather_last_dim_from_tensor_parallel_region, all_to_all_hp2sp,
|
||||
all_to_all_sp2hp, gather_from_sequence_parallel_region,
|
||||
reduce_scatter_last_dim_to_tensor_parallel_region)
|
||||
from vllm_ascend.ops.comm_utils import async_all_to_all
|
||||
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
|
||||
|
||||
|
||||
@@ -460,6 +457,31 @@ class MoEAlltoAllSeqOverLapDispatcher(MoEDispatcher):
|
||||
return output, None
|
||||
|
||||
|
||||
_Dispatchers: Dict[str, Any] = {}
|
||||
|
||||
|
||||
def _register_token_dispatcher(dispatcher: Any):
|
||||
_Dispatchers[dispatcher.__class__.__name__] = dispatcher
|
||||
|
||||
|
||||
def get_token_dispatcher(name: str):
|
||||
return _Dispatchers.get(name)
|
||||
|
||||
|
||||
def setup_token_dispatchers(ep_size: int, **kwargs):
|
||||
existing_dispatchers = set(_Dispatchers.keys())
|
||||
|
||||
if ep_size == 1 and "TokenDispatcherWithAllGather" not in existing_dispatchers:
|
||||
_register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs))
|
||||
elif ep_size < 16 and "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
|
||||
_register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
|
||||
elif ep_size >= 16:
|
||||
if "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
|
||||
_register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
|
||||
if "TokenDispatcherWithMC2" not in existing_dispatchers:
|
||||
_register_token_dispatcher(TokenDispatcherWithMC2(**kwargs))
|
||||
|
||||
|
||||
class MoETokenDispatcher(ABC):
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
@@ -484,18 +506,19 @@ class MoETokenDispatcher(ABC):
|
||||
return get_ep_group().world_size
|
||||
|
||||
@abstractmethod
|
||||
def token_dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
):
|
||||
def token_dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False):
|
||||
raise NotImplementedError("Dispatch function not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
@@ -516,40 +539,39 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
|
||||
self.ep_rank_id = get_mc2_group().rank_in_group
|
||||
self.ep_world_size = get_mc2_group().world_size
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
self.enable_dispatch_v2 = hasattr(torch_npu,
|
||||
"npu_moe_distribute_dispatch_v2")
|
||||
self.need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
|
||||
or self.torchair_graph_enabled)
|
||||
self.need_extra_args = (
|
||||
get_ascend_soc_version() == AscendSocVersion.A3)
|
||||
|
||||
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
|
||||
self.a3_need_extra_args = \
|
||||
get_ascend_soc_version() == AscendSocVersion.A3
|
||||
self.output = None
|
||||
self.dynamic_scale = None
|
||||
self.assist_info_for_combine = None
|
||||
self.ep_recv_counts = None
|
||||
self.shared_act = None
|
||||
self.topk_ids = None
|
||||
self.topk_weights = None
|
||||
self.shared_experts = None
|
||||
self.mc2_mask = None
|
||||
|
||||
def get_dispatch_mc2_kwargs(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
global_redundant_expert_num: int = 0):
|
||||
quant_mode = 0
|
||||
forward_context = get_forward_context()
|
||||
mc2_mask = forward_context.mc2_mask
|
||||
def get_dispatch_mc2_kwargs(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
global_redundant_expert_num: int = 0,
|
||||
):
|
||||
if self.with_quant:
|
||||
quant_mode = 2
|
||||
if (expert_map is not None):
|
||||
moe_expert_num = len(expert_map) + global_redundant_expert_num
|
||||
else:
|
||||
moe_expert_num = global_redundant_expert_num
|
||||
else:
|
||||
quant_mode = 0
|
||||
moe_expert_num = len(expert_map)
|
||||
kwargs_mc2 = {
|
||||
"x": hidden_states,
|
||||
@@ -575,28 +597,30 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
})
|
||||
if self.a3_need_extra_args and self.enable_dispatch_v2:
|
||||
stage1_kwargs.update({
|
||||
"x_active_mask": mc2_mask,
|
||||
"x_active_mask": self.mc2_mask,
|
||||
})
|
||||
|
||||
kwargs_mc2.update(stage1_kwargs)
|
||||
return kwargs_mc2
|
||||
|
||||
def token_dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
):
|
||||
def token_dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False):
|
||||
self.expert_map = expert_map
|
||||
self.topk_ids = topk_ids
|
||||
self.topk_weights = topk_weights
|
||||
self.shared_experts = shared_experts
|
||||
self.mc2_mask = mc2_mask
|
||||
|
||||
kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights,
|
||||
topk_ids, expert_map,
|
||||
@@ -606,28 +630,27 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
|
||||
**kwargs_mc2)
|
||||
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||
expand_x, self.dynamic_scale, self.assist_info_for_combine, \
|
||||
expand_x, dynamic_scale, self.assist_info_for_combine, \
|
||||
expert_token_nums, self.ep_recv_counts = self.output[0:5]
|
||||
|
||||
if self.with_quant:
|
||||
if shared_experts is not None:
|
||||
with npu_stream_switch("moe_secondary", 0):
|
||||
npu_wait_tensor(shared_gate_up, expand_x)
|
||||
shared_act_out = shared_experts.act_fn(
|
||||
(shared_gate_up, shared_dequant_scale))
|
||||
self.shared_act, self.swiglu_out_scale = \
|
||||
shared_act_out[0], shared_act_out[1]
|
||||
shared_act_out = shared_experts.act_fn(
|
||||
(shared_gate_up, shared_dequant_scale))
|
||||
self.shared_act, self.swiglu_out_scale = \
|
||||
shared_act_out[0], shared_act_out[1]
|
||||
|
||||
else:
|
||||
if shared_experts is not None:
|
||||
with npu_stream_switch("moe_secondary", 0):
|
||||
npu_wait_tensor(hidden_states, topk_weights)
|
||||
shared_gate_up, _ = shared_experts.gate_up_proj(
|
||||
hidden_states)
|
||||
npu_wait_tensor(shared_gate_up, expand_x)
|
||||
self.shared_act = shared_experts.act_fn(shared_gate_up)
|
||||
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
|
||||
self.shared_act = shared_experts.act_fn(shared_gate_up)
|
||||
group_list_type = 1
|
||||
return group_list_type, expand_x, expert_token_nums
|
||||
return {
|
||||
"group_list_type": group_list_type,
|
||||
"hidden_states": expand_x,
|
||||
"group_list": expert_token_nums,
|
||||
"dynamic_scale": dynamic_scale,
|
||||
}
|
||||
|
||||
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor):
|
||||
assert self.expert_map is not None
|
||||
@@ -635,8 +658,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
assert self.topk_ids is not None
|
||||
assert self.output is not None
|
||||
moe_expert_num = len(self.expert_map)
|
||||
forward_context = get_forward_context()
|
||||
mc2_mask = forward_context.mc2_mask
|
||||
# moeCombine
|
||||
kwargs_mc2 = {
|
||||
"expand_x": hidden_states,
|
||||
@@ -677,7 +698,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
})
|
||||
if self.a3_need_extra_args and self.enable_dispatch_v2:
|
||||
stage3_kwargs.update({
|
||||
"x_active_mask": mc2_mask,
|
||||
"x_active_mask": self.mc2_mask,
|
||||
})
|
||||
kwargs_mc2.update(stage3_kwargs)
|
||||
return kwargs_mc2
|
||||
@@ -685,7 +706,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
def token_combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
|
||||
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states)
|
||||
hidden_states = torch_npu.npu_moe_distribute_combine_v2(
|
||||
**kwargs_mc2
|
||||
@@ -695,15 +715,11 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
return hidden_states
|
||||
else:
|
||||
if self.with_quant:
|
||||
with npu_stream_switch("moe_secondary", 0):
|
||||
npu_wait_tensor(self.shared_act, hidden_states)
|
||||
shared_hidden_states, _ = self.shared_experts.down_proj(
|
||||
(self.shared_act, self.swiglu_out_scale))
|
||||
shared_hidden_states, _ = self.shared_experts.down_proj(
|
||||
(self.shared_act, self.swiglu_out_scale))
|
||||
else:
|
||||
with npu_stream_switch("moe_secondary", 0):
|
||||
npu_wait_tensor(self.shared_act, hidden_states)
|
||||
shared_hidden_states, _ = self.shared_experts.down_proj(
|
||||
self.shared_act)
|
||||
shared_hidden_states, _ = self.shared_experts.down_proj(
|
||||
self.shared_act)
|
||||
return hidden_states, shared_hidden_states
|
||||
|
||||
|
||||
@@ -711,13 +727,9 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.apply_router_weight_on_input = kwargs.get(
|
||||
"apply_router_weight_on_input")
|
||||
self.top_k = kwargs.get("top_k")
|
||||
self.apply_router_weight_on_input = False
|
||||
self.max_num_tokens = kwargs.get("max_num_tokens")
|
||||
ep_size = kwargs.get("ep_size")
|
||||
if ep_size is not None:
|
||||
self.num_experts_local = self.num_experts // ep_size
|
||||
self.num_experts_local = kwargs.get("num_local_experts", 0)
|
||||
self.sorted_weights = None
|
||||
self.expanded_row_idx = None
|
||||
self.sorted_token_indices = None
|
||||
@@ -727,20 +739,20 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
self.topk_weights = None
|
||||
self.topk_ids = None
|
||||
|
||||
def token_dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
):
|
||||
def token_dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False):
|
||||
self.original_shape = hidden_states.shape
|
||||
# assert len(original_shape) == 2
|
||||
|
||||
num_tokens = hidden_states.shape[:-1].numel()
|
||||
dtype = hidden_states.dtype
|
||||
@@ -748,9 +760,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
self.expert_map = expert_map
|
||||
self.topk_weights = topk_weights
|
||||
self.topk_ids = topk_ids
|
||||
# assert dtype in [torch.float32, torch.float16, torch.bfloat16
|
||||
# ], "Only float32, float16, and bfsloat16 are supported"
|
||||
|
||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
||||
if self.apply_router_weight_on_input:
|
||||
assert (topk_weights.dim() == 2
|
||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
@@ -803,19 +813,13 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
sorted_hidden_states = hidden_states[self.sorted_token_indices]
|
||||
if self.with_quant:
|
||||
group_list_type = 1
|
||||
expert_tokens = token_counts
|
||||
else:
|
||||
expert_tokens = torch.cumsum(token_counts,
|
||||
dim=0,
|
||||
dtype=torch.int64)
|
||||
group_list_type = 0
|
||||
else:
|
||||
row_idx_len = num_tokens * self.top_k
|
||||
row_idx = (torch.arange(0,
|
||||
row_idx_len,
|
||||
dtype=torch.int32,
|
||||
device=device).view(self.top_k,
|
||||
-1).permute(
|
||||
1, 0).contiguous())
|
||||
active_num = self.max_num_tokens if self.max_num_tokens is not None else num_tokens
|
||||
sorted_hidden_states, self.expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
@@ -827,18 +831,23 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
expanded_expert_idx, self.num_experts_local)
|
||||
expert_tokens = expert_tokens.to(torch.int64)
|
||||
group_list_type = 0
|
||||
return group_list_type, sorted_hidden_states, expert_tokens
|
||||
return {
|
||||
"group_list_type": group_list_type,
|
||||
"hidden_states": sorted_hidden_states,
|
||||
"group_list": expert_tokens,
|
||||
}
|
||||
|
||||
def token_combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
assert self.mask is not None
|
||||
assert self.sorted_token_indices is not None
|
||||
assert self.sorted_weights is not None
|
||||
assert self.original_shape is not None
|
||||
dtype = hidden_states.dtype
|
||||
device = hidden_states.device
|
||||
if self.expert_map is not None:
|
||||
assert self.mask is not None
|
||||
assert self.sorted_token_indices is not None
|
||||
assert self.sorted_weights is not None
|
||||
|
||||
weighted_down_out = hidden_states * \
|
||||
self.sorted_weights.unsqueeze(1)
|
||||
|
||||
@@ -887,7 +896,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
expanded_src_to_dst_row=self.expanded_row_idx,
|
||||
export_for_source_row=self.topk_ids,
|
||||
)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
@@ -895,29 +903,27 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(MoETokenDispatcher, self).__init__(**kwargs)
|
||||
self.apply_router_weight_on_input = kwargs.get(
|
||||
"apply_router_weight_on_input")
|
||||
ep_size = kwargs.get("ep_size")
|
||||
self.local_ep = ep_size
|
||||
assert self.local_ep is not None
|
||||
super().__init__(**kwargs)
|
||||
self.apply_router_weight_on_input = False
|
||||
self.local_ep = 1
|
||||
self.local_num_experts = self.num_experts // self.local_ep
|
||||
self.local_num_group = self.top_k // self.local_ep
|
||||
self.bsz = None
|
||||
|
||||
def token_dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
):
|
||||
|
||||
def token_dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False):
|
||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
||||
if self.apply_router_weight_on_input:
|
||||
assert (topk_weights.dim() == 2
|
||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
@@ -932,7 +938,7 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
|
||||
flatten_topk_ids = topk_ids.view(-1)
|
||||
self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
||||
self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32)
|
||||
self.sorted_hidden_states = hidden_states.index_select(
|
||||
sorted_hidden_states = hidden_states.index_select(
|
||||
0, self.sorted_topk_ids // self.local_num_group)
|
||||
|
||||
experts_id = torch.arange(0,
|
||||
@@ -942,15 +948,20 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
|
||||
num_tokens_per_expert = (
|
||||
flatten_topk_ids.unsqueeze(-1) == experts_id).to(
|
||||
torch.float32).sum(0)
|
||||
self.topk_scales = topk_weights.view(-1).index_select(
|
||||
topk_scales = topk_weights.view(-1).index_select(
|
||||
0, self.sorted_topk_ids).unsqueeze(-1)
|
||||
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
|
||||
return hidden_states, group_list
|
||||
group_list_type = 0
|
||||
return {
|
||||
"group_list_type": group_list_type,
|
||||
"hidden_states": sorted_hidden_states,
|
||||
"group_list": group_list,
|
||||
"topk_scales": topk_scales,
|
||||
}
|
||||
|
||||
def token_combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
assert self.local_ep is not None
|
||||
unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
|
||||
torch.int32)
|
||||
unsorted_hidden_states = hidden_states.index_select(
|
||||
@@ -1009,18 +1020,19 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
self.local_expert_indices[i + 1] -
|
||||
1), "local_expert_indices must be continuous"
|
||||
|
||||
def token_dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
):
|
||||
def token_dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
row_idx: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[torch.Tensor] = None,
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False):
|
||||
self.hidden_shape = hidden_states.shape
|
||||
self.topk_weights = topk_weights
|
||||
assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights"
|
||||
|
||||
Reference in New Issue
Block a user