[Kernel] Add moe normal ops (#4810)
### What this PR does / why we need it?
1.Add the implementation of normal Aclnn operators: MoeCombineNormal,
MoeDispatchNormal, NotifyDispatch,and DispatchLayout.
- MoeCombineNormal: Implements the combine logic within MoE operations.
- MoeDispatchNormal: Implements the dispatch logic within MoE
operations.
- NotifyDispatch: Exchanges topk_idx information among different ranks
to calculate the device memory required for the dispatch stage.
- DispatchLayout: Used to calculate information related to the device
memory layout for the dispatch stage.
2.Provide PyTorch interfaces for normal operators—get_dispatch_layout,
dispatch_prefill, and combine_prefill—to be used for MoE communication
during the prefill stage in vLLM.
- get_dispatch_layout: Calculates information related to the device
memory layout for the dispatch operator, and is called before
dispatch_prefill.
- dispatch_prefill: Initiates the dispatch operation.
- combine_prefill: Initiates the combine operation.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
The functionality has already been validated using the local Qwen model.
Test cases will be added after support for multi-NPU use cases in the CI
pipeline is finalized.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: shiro-zzzz <zhangdianhao@huawei.com>
This commit is contained in:
@@ -20,6 +20,7 @@
|
||||
#include <torch/torch.h>
|
||||
#include <torch_npu/csrc/core/npu/NPUStream.h>
|
||||
#include <torch_npu/csrc/framework/OpCommand.h>
|
||||
#include <torch_npu/csrc/framework/utils/OpPreparation.h>
|
||||
#include "torch_npu/csrc/core/npu/NPUGuard.h"
|
||||
#include <torch_npu/csrc/npu/Module.h>
|
||||
#include "acl/acl.h"
|
||||
@@ -838,6 +839,246 @@ std::tuple<at::Tensor, at::Tensor> matmul_allreduce_add_rmsnorm(
|
||||
return {output, add_out};
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> get_dispatch_layout(const at::Tensor& topk_idx, int64_t num_experts,
|
||||
int64_t num_ranks) {
|
||||
TORCH_BIND_ASSERT(topk_idx.dim() == 2);
|
||||
TORCH_BIND_ASSERT(topk_idx.is_contiguous());
|
||||
TORCH_BIND_ASSERT(num_experts > 0);
|
||||
|
||||
const int num_tokens = topk_idx.size(0);
|
||||
const int num_topk = topk_idx.size(1);
|
||||
|
||||
auto device = topk_idx.device();
|
||||
auto num_tokens_per_expert = at::zeros({num_experts}, at::dtype(at::kInt).device(device));
|
||||
auto num_tokens_per_rank = at::zeros({num_ranks}, at::dtype(at::kInt).device(device));
|
||||
auto is_token_in_rank = at::zeros({num_tokens, num_ranks}, at::dtype(at::kInt).device(device));
|
||||
|
||||
EXEC_NPU_CMD(aclnnDispatchLayout,
|
||||
topk_idx,
|
||||
num_tokens,
|
||||
num_ranks,
|
||||
num_experts,
|
||||
num_topk,
|
||||
num_tokens_per_rank,
|
||||
num_tokens_per_expert,
|
||||
is_token_in_rank);
|
||||
|
||||
auto is_token_in_rank_bool = is_token_in_rank.to(at::kBool);
|
||||
|
||||
return std::make_tuple(num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank_bool);
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> dispatch_prefill(
|
||||
const at::Tensor& x, const at::Tensor& topk_idx, const at::Tensor& topk_weights,
|
||||
const at::Tensor& num_tokens_per_rank, const at::Tensor& is_token_in_rank, at::Tensor& num_tokens_per_expert,
|
||||
int64_t num_worst_tokens, c10::string_view groupEp, int64_t rank, int64_t num_ranks) {
|
||||
std::vector<char> group_ep_chrs(groupEp.begin(), groupEp.end());
|
||||
group_ep_chrs.push_back('\0');
|
||||
char* group_ep_ptr = &group_ep_chrs[0];
|
||||
at::Tensor new_x = x;
|
||||
|
||||
// Type checks
|
||||
TORCH_BIND_ASSERT(is_token_in_rank.scalar_type() == at::kBool);
|
||||
TORCH_BIND_ASSERT(num_tokens_per_expert.scalar_type() == at::kInt);
|
||||
TORCH_BIND_ASSERT(num_tokens_per_rank.scalar_type() == at::kInt);
|
||||
|
||||
// Shape and contiguous checks
|
||||
TORCH_BIND_ASSERT(new_x.dim() == 2 and new_x.is_contiguous());
|
||||
// TORCH_BIND_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0);
|
||||
TORCH_BIND_ASSERT(is_token_in_rank.dim() == 2 and is_token_in_rank.is_contiguous());
|
||||
TORCH_BIND_ASSERT(is_token_in_rank.size(0) == new_x.size(0) and is_token_in_rank.size(1) == num_ranks);
|
||||
TORCH_BIND_ASSERT(num_tokens_per_expert.dim() == 1 and num_tokens_per_expert.is_contiguous());
|
||||
TORCH_BIND_ASSERT(num_tokens_per_expert.size(0) % num_ranks == 0);
|
||||
TORCH_BIND_ASSERT(num_tokens_per_rank.dim() == 1 and num_tokens_per_rank.is_contiguous());
|
||||
TORCH_BIND_ASSERT(num_tokens_per_rank.size(0) == num_ranks);
|
||||
|
||||
auto num_tokens = static_cast<int>(new_x.size(0));
|
||||
auto hidden = static_cast<int>(new_x.size(1));
|
||||
auto num_experts = static_cast<int64_t>(num_tokens_per_expert.size(0));
|
||||
auto num_local_experts = static_cast<int>(num_experts / num_ranks);
|
||||
|
||||
// Top-k checks
|
||||
int num_topk = 0;
|
||||
num_topk = static_cast<int>(topk_idx.size(1));
|
||||
TORCH_BIND_ASSERT(num_experts > 0);
|
||||
TORCH_BIND_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous());
|
||||
TORCH_BIND_ASSERT(topk_weights.dim() == 2 and topk_weights.is_contiguous());
|
||||
TORCH_BIND_ASSERT(num_tokens == topk_idx.size(0));
|
||||
TORCH_BIND_ASSERT(num_topk == topk_weights.size(1));
|
||||
TORCH_BIND_ASSERT(topk_weights.scalar_type() == at::kFloat);
|
||||
|
||||
int send_per_group = 3; // (send_to_expert_num, send_to_expert_offset, send_rank_tokens)
|
||||
|
||||
auto send_data = at::empty({num_experts * send_per_group}, at::dtype(at::kInt).device(x.device()));
|
||||
int64_t send_count = send_per_group * num_local_experts * num_ranks;
|
||||
|
||||
auto send_data_offset = at::empty({num_experts}, at::dtype(at::kInt).device(x.device()));
|
||||
at::Tensor recv_data = at::empty({num_experts * send_per_group}, at::dtype(at::kInt).device(x.device()));
|
||||
|
||||
int64_t local_rank_size = num_ranks;
|
||||
int64_t local_rank_id = rank % local_rank_size;
|
||||
|
||||
EXEC_NPU_CMD(aclnnNotifyDispatch,
|
||||
send_data,
|
||||
num_tokens_per_expert,
|
||||
send_count,
|
||||
num_tokens,
|
||||
group_ep_ptr, // commGroup
|
||||
num_ranks, // rankSize
|
||||
rank, // rankId
|
||||
local_rank_size,
|
||||
local_rank_id,
|
||||
send_data_offset,
|
||||
recv_data);
|
||||
|
||||
auto options_cpu = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
|
||||
std::vector<int32_t> local_expert_acc(num_experts, 0);
|
||||
auto send_token_idx_cpu = at::empty({num_tokens, num_topk}, options_cpu);
|
||||
auto send_token_idx_ptr = send_token_idx_cpu.data_ptr<int>();
|
||||
|
||||
auto topk_idx_cpu = topk_idx.to(at::kCPU);
|
||||
auto topk_idx_ptr = topk_idx_cpu.data_ptr<int64_t>();
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
for (int j = 0; j < num_topk; ++j) {
|
||||
int64_t expert_idx = topk_idx_ptr[i * num_topk + j];
|
||||
if (expert_idx >= 0) {
|
||||
int32_t cnt = local_expert_acc[expert_idx];
|
||||
send_token_idx_ptr[i * num_topk + j] = cnt;
|
||||
local_expert_acc[expert_idx]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_BIND_ASSERT(recv_data.dim() == 1 and recv_data.is_contiguous());
|
||||
TORCH_BIND_ASSERT(recv_data.size(0) % num_experts == 0);
|
||||
at::Tensor recv_offset_cpu = at::empty({num_experts}, options_cpu);
|
||||
at::Tensor recv_count_cpu = at::empty({num_experts}, options_cpu);
|
||||
auto recv_data_cpu = recv_data.to(at::kCPU);
|
||||
auto recv_data_ptr = recv_data_cpu.data_ptr<int>();
|
||||
auto recv_count_ptr = recv_count_cpu.data_ptr<int>();
|
||||
auto recv_offset_ptr = recv_offset_cpu.data_ptr<int>();
|
||||
int64_t total_recv_tokens = 0;
|
||||
int64_t num_max_dispatch_tokens_per_rank = 0;
|
||||
std::vector<int64_t> num_recv_tokens_per_expert_list;
|
||||
|
||||
for (int64_t local_e = 0; local_e < num_local_experts; ++local_e) {
|
||||
int64_t local_expert_recv_tokens = 0;
|
||||
for (int64_t src_rank = 0; src_rank < num_ranks; ++src_rank) {
|
||||
int64_t index = local_e * num_ranks + src_rank;
|
||||
int64_t pair_idx = send_per_group * (src_rank * num_local_experts + local_e);
|
||||
|
||||
int recv_cnt = recv_data_ptr[pair_idx]; // count from this src_rank for
|
||||
// this global_expert
|
||||
int recv_off = recv_data_ptr[pair_idx + 1]; // offset in that src_rank's window
|
||||
int64_t send_num_tokens = recv_data_ptr[pair_idx + 2]; // all bs from rank
|
||||
|
||||
total_recv_tokens += recv_cnt;
|
||||
recv_count_ptr[index] = total_recv_tokens;
|
||||
recv_offset_ptr[index] = recv_off;
|
||||
num_max_dispatch_tokens_per_rank = std::max(num_max_dispatch_tokens_per_rank, send_num_tokens);
|
||||
|
||||
local_expert_recv_tokens += recv_cnt;
|
||||
}
|
||||
num_recv_tokens_per_expert_list.push_back(local_expert_recv_tokens);
|
||||
}
|
||||
auto option = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU);
|
||||
at::Tensor num_recv_tokens_per_expert = torch::from_blob(
|
||||
num_recv_tokens_per_expert_list.data(), {static_cast<int64_t>(num_recv_tokens_per_expert_list.size())}, option)
|
||||
.clone();
|
||||
|
||||
at::Tensor expert_ids = topk_idx.to(at::kInt);
|
||||
int64_t tp_size = 1;
|
||||
int64_t tp_rank = 0;
|
||||
int64_t quant_mode = 0;
|
||||
int64_t global_bs = static_cast<int64_t>(
|
||||
std::max(num_max_dispatch_tokens_per_rank * num_ranks, static_cast<int64_t>(num_worst_tokens)));
|
||||
|
||||
auto send_token_idx = send_token_idx_cpu.to(x.device());
|
||||
auto recv_offset = recv_offset_cpu.to(x.device());
|
||||
auto recv_count = recv_count_cpu.to(x.device());
|
||||
|
||||
int total_cnt = total_recv_tokens;
|
||||
if (total_cnt == 0) {
|
||||
total_cnt = 1;
|
||||
}
|
||||
auto expandx_out = at::empty({total_cnt, hidden}, x.options());
|
||||
auto dynamic_scales_out = at::empty({total_cnt}, at::dtype(at::kFloat).device(x.device()));
|
||||
auto expand_idx_out = at::empty({total_cnt * 3}, at::dtype(at::kInt).device(x.device()));
|
||||
|
||||
EXEC_NPU_CMD(aclnnMoeDispatchNormal,
|
||||
new_x,
|
||||
expert_ids,
|
||||
send_data_offset,
|
||||
send_token_idx,
|
||||
recv_offset,
|
||||
recv_count,
|
||||
group_ep_ptr, // commGroup
|
||||
num_ranks, // rankSize
|
||||
rank, // rankId
|
||||
group_ep_ptr,
|
||||
tp_size,
|
||||
tp_rank,
|
||||
num_experts,
|
||||
quant_mode,
|
||||
global_bs,
|
||||
expandx_out,
|
||||
dynamic_scales_out,
|
||||
expand_idx_out);
|
||||
|
||||
// Return values
|
||||
return {expandx_out, expand_idx_out, recv_count, num_recv_tokens_per_expert};
|
||||
}
|
||||
|
||||
at::Tensor combine_prefill(const at::Tensor& x, const at::Tensor& topk_idx, const at::Tensor& topk_weights,
|
||||
const at::Tensor& src_idx, const at::Tensor& send_head, c10::string_view groupEp,
|
||||
int64_t rank, int64_t num_ranks) {
|
||||
std::vector<char> group_ep_chrs(groupEp.begin(), groupEp.end());
|
||||
group_ep_chrs.push_back('\0');
|
||||
char* group_ep_ptr = &group_ep_chrs[0];
|
||||
|
||||
TORCH_BIND_ASSERT(x.dim() == 2 and x.is_contiguous());
|
||||
at::Tensor recv_x = x;
|
||||
|
||||
at::Tensor topk_idx_p = topk_idx;
|
||||
|
||||
auto topk_idx_int32 = topk_idx_p.to(at::kInt);
|
||||
at::Tensor expand_ids = topk_idx_int32;
|
||||
at::Tensor token_src_info = src_idx;
|
||||
at::Tensor ep_send_counts = send_head;
|
||||
auto device = x.device();
|
||||
|
||||
const int num_tokens = topk_idx_p.size(0);
|
||||
const int num_topk = topk_idx_p.size(1);
|
||||
|
||||
int64_t hidden = static_cast<int>(recv_x.size(1));
|
||||
at::Tensor tp_send_counts = at::empty({1}, at::dtype(at::kInt).device(device));
|
||||
int64_t tp_world_size = 1;
|
||||
int64_t tp_rankId = 0;
|
||||
int64_t moe_expert_number = send_head.size(0);
|
||||
int64_t global_bs = topk_idx_p.size(0) * num_ranks;
|
||||
|
||||
// Combine data
|
||||
auto combined_x = torch::empty({topk_weights.size(0), hidden}, x.options());
|
||||
|
||||
EXEC_NPU_CMD(aclnnMoeCombineNormal,
|
||||
recv_x,
|
||||
token_src_info,
|
||||
ep_send_counts,
|
||||
topk_weights,
|
||||
tp_send_counts,
|
||||
group_ep_ptr,
|
||||
num_ranks,
|
||||
rank,
|
||||
group_ep_ptr,
|
||||
tp_world_size,
|
||||
tp_rankId,
|
||||
moe_expert_number,
|
||||
global_bs,
|
||||
combined_x);
|
||||
|
||||
return combined_x;
|
||||
}
|
||||
|
||||
} // namespace vllm_ascend
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
@@ -955,4 +1196,25 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
ops.def("matmul_allreduce_add_rmsnorm(Tensor x1, Tensor x2, Tensor residual, Tensor gamma, \
|
||||
str groupTp, int tpRankSize, int tpRankId, float epsilon, bool isTransB, bool isGatherAddOut) -> (Tensor output, Tensor add_out)");
|
||||
ops.impl("matmul_allreduce_add_rmsnorm", torch::kPrivateUse1, &vllm_ascend::matmul_allreduce_add_rmsnorm);
|
||||
|
||||
ops.def("get_dispatch_layout(Tensor topk_idx, int num_experts, int "
|
||||
"num_ranks) -> (Tensor num_tokens_per_rank, Tensor "
|
||||
"num_tokens_per_expert, Tensor is_token_in_rank_bool)");
|
||||
ops.impl("get_dispatch_layout", torch::kPrivateUse1,
|
||||
&vllm_ascend::get_dispatch_layout);
|
||||
|
||||
ops.def(
|
||||
"dispatch_prefill(Tensor x, Tensor topk_idx, Tensor topk_weights, "
|
||||
"Tensor num_tokens_per_rank, Tensor is_token_in_rank, Tensor "
|
||||
"num_tokens_per_expert, int num_worst_tokens, str groupEp, int rank, "
|
||||
"int num_ranks) -> (Tensor expandx_out, Tensor expand_idx_out, Tensor "
|
||||
"recv_count, Tensor num_recv_tokens_per_expert)");
|
||||
ops.impl("dispatch_prefill", torch::kPrivateUse1,
|
||||
&vllm_ascend::dispatch_prefill);
|
||||
|
||||
ops.def("combine_prefill(Tensor x, Tensor topk_idx, Tensor topk_weights, "
|
||||
"Tensor src_idx, Tensor send_head, str grouEp, int rank, int "
|
||||
"num_ranks) -> Tensor");
|
||||
ops.impl("combine_prefill", torch::kPrivateUse1,
|
||||
&vllm_ascend::combine_prefill);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user