[Kernel] Add moe normal ops (#4810)

### What this PR does / why we need it?
1. Add the implementation of the normal Aclnn operators: MoeCombineNormal,
MoeDispatchNormal, NotifyDispatch, and DispatchLayout.

- MoeCombineNormal: Implements the combine logic within MoE operations.
- MoeDispatchNormal: Implements the dispatch logic within MoE
operations.
- NotifyDispatch: Exchanges topk_idx information among different ranks
to calculate the device memory required for the dispatch stage.
- DispatchLayout: Computes the device memory layout information needed by
the dispatch stage.

2. Provide PyTorch interfaces for the normal operators (get_dispatch_layout,
dispatch_prefill, and combine_prefill) to be used for MoE communication
during the prefill stage in vLLM; a usage sketch follows the list below.

- get_dispatch_layout: Computes the device memory layout information
consumed by the dispatch operator; it must be called before
dispatch_prefill.
- dispatch_prefill: Initiates the dispatch operation.
- combine_prefill: Initiates the combine operation.
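
For context, the sketch below (not code from this PR) shows how the three interfaces are expected to chain together on a single rank during prefill. The wrapper `moe_prefill_step`, the group name `"hccl_world_group"`, the placeholder expert computation, and the mapping of `dispatch_prefill`'s `expand_idx_out`/`recv_count` outputs onto `combine_prefill`'s `src_idx`/`send_head` parameters are assumptions inferred from the C++ signatures in the diff, not from vLLM code.

```cpp
// Hedged usage sketch, not part of the PR. Assumed: the declarations below match the
// diff, the EP communication group is named "hccl_world_group", and dispatch_prefill's
// expand_idx_out / recv_count outputs feed combine_prefill's src_idx / send_head.
#include <ATen/ATen.h>
#include <c10/util/string_view.h>
#include <cstdint>
#include <tuple>

namespace vllm_ascend {
// Declarations matching the definitions added by this PR.
std::tuple<at::Tensor, at::Tensor, at::Tensor> get_dispatch_layout(
    const at::Tensor& topk_idx, int64_t num_experts, int64_t num_ranks);
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> dispatch_prefill(
    const at::Tensor& x, const at::Tensor& topk_idx, const at::Tensor& topk_weights,
    const at::Tensor& num_tokens_per_rank, const at::Tensor& is_token_in_rank,
    at::Tensor& num_tokens_per_expert, int64_t num_worst_tokens,
    c10::string_view groupEp, int64_t rank, int64_t num_ranks);
at::Tensor combine_prefill(const at::Tensor& x, const at::Tensor& topk_idx,
                           const at::Tensor& topk_weights, const at::Tensor& src_idx,
                           const at::Tensor& send_head, c10::string_view groupEp,
                           int64_t rank, int64_t num_ranks);
}  // namespace vllm_ascend

// Hypothetical wrapper: one MoE prefill step on one rank.
at::Tensor moe_prefill_step(const at::Tensor& x, const at::Tensor& topk_idx,
                            const at::Tensor& topk_weights, int64_t num_experts,
                            int64_t rank, int64_t num_ranks) {
  // 1) Layout: per-rank / per-expert token counts and the token-to-rank mask.
  auto [tokens_per_rank, tokens_per_expert, token_in_rank] =
      vllm_ascend::get_dispatch_layout(topk_idx, num_experts, num_ranks);
  // 2) Dispatch: shuffle tokens to the ranks that own their target experts.
  auto [expand_x, expand_idx, recv_count, recv_per_expert] =
      vllm_ascend::dispatch_prefill(x, topk_idx, topk_weights, tokens_per_rank,
                                    token_in_rank, tokens_per_expert,
                                    /*num_worst_tokens=*/0, "hccl_world_group",
                                    rank, num_ranks);
  (void)recv_per_expert;  // per-local-expert receive counts, unused in this sketch
  // ... run the local experts on expand_x; expert_out stands in for their output ...
  at::Tensor expert_out = expand_x;
  // 3) Combine: send expert outputs back and weighted-sum them per source token.
  return vllm_ascend::combine_prefill(expert_out, topk_idx, topk_weights, expand_idx,
                                      recv_count, "hccl_world_group", rank, num_ranks);
}
```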

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
The functionality has been validated locally with a Qwen model. Test cases
will be added once support for multi-NPU use cases in the CI pipeline is
finalized.

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: shiro-zzzz <zhangdianhao@huawei.com>
39 changed files with 5365 additions and 4 deletions


@@ -20,6 +20,7 @@
#include <torch/torch.h>
#include <torch_npu/csrc/core/npu/NPUStream.h>
#include <torch_npu/csrc/framework/OpCommand.h>
#include <torch_npu/csrc/framework/utils/OpPreparation.h>
#include "torch_npu/csrc/core/npu/NPUGuard.h"
#include <torch_npu/csrc/npu/Module.h>
#include "acl/acl.h"
@@ -838,6 +839,246 @@ std::tuple<at::Tensor, at::Tensor> matmul_allreduce_add_rmsnorm(
return {output, add_out};
}
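// get_dispatch_layout: from each token's top-k expert ids, count the tokens destined
// for every expert and every rank and build the [num_tokens, num_ranks] mask of ranks
// each token must be sent to (computed on device by aclnnDispatchLayout).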
std::tuple<at::Tensor, at::Tensor, at::Tensor> get_dispatch_layout(const at::Tensor& topk_idx, int64_t num_experts,
int64_t num_ranks) {
TORCH_BIND_ASSERT(topk_idx.dim() == 2);
TORCH_BIND_ASSERT(topk_idx.is_contiguous());
TORCH_BIND_ASSERT(num_experts > 0);
const int num_tokens = topk_idx.size(0);
const int num_topk = topk_idx.size(1);
auto device = topk_idx.device();
auto num_tokens_per_expert = at::zeros({num_experts}, at::dtype(at::kInt).device(device));
auto num_tokens_per_rank = at::zeros({num_ranks}, at::dtype(at::kInt).device(device));
auto is_token_in_rank = at::zeros({num_tokens, num_ranks}, at::dtype(at::kInt).device(device));
EXEC_NPU_CMD(aclnnDispatchLayout,
topk_idx,
num_tokens,
num_ranks,
num_experts,
num_topk,
num_tokens_per_rank,
num_tokens_per_expert,
is_token_in_rank);
auto is_token_in_rank_bool = is_token_in_rank.to(at::kBool);
return std::make_tuple(num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank_bool);
}
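// dispatch_prefill: exchange per-expert send counts across ranks (aclnnNotifyDispatch),
// derive the receive counts/offsets on the host, then shuffle the tokens to the ranks
// that own their target experts (aclnnMoeDispatchNormal).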
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> dispatch_prefill(
const at::Tensor& x, const at::Tensor& topk_idx, const at::Tensor& topk_weights,
const at::Tensor& num_tokens_per_rank, const at::Tensor& is_token_in_rank, at::Tensor& num_tokens_per_expert,
int64_t num_worst_tokens, c10::string_view groupEp, int64_t rank, int64_t num_ranks) {
std::vector<char> group_ep_chrs(groupEp.begin(), groupEp.end());
group_ep_chrs.push_back('\0');
char* group_ep_ptr = &group_ep_chrs[0];
at::Tensor new_x = x;
// Type checks
TORCH_BIND_ASSERT(is_token_in_rank.scalar_type() == at::kBool);
TORCH_BIND_ASSERT(num_tokens_per_expert.scalar_type() == at::kInt);
TORCH_BIND_ASSERT(num_tokens_per_rank.scalar_type() == at::kInt);
// Shape and contiguous checks
TORCH_BIND_ASSERT(new_x.dim() == 2 and new_x.is_contiguous());
// TORCH_BIND_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0);
TORCH_BIND_ASSERT(is_token_in_rank.dim() == 2 and is_token_in_rank.is_contiguous());
TORCH_BIND_ASSERT(is_token_in_rank.size(0) == new_x.size(0) and is_token_in_rank.size(1) == num_ranks);
TORCH_BIND_ASSERT(num_tokens_per_expert.dim() == 1 and num_tokens_per_expert.is_contiguous());
TORCH_BIND_ASSERT(num_tokens_per_expert.size(0) % num_ranks == 0);
TORCH_BIND_ASSERT(num_tokens_per_rank.dim() == 1 and num_tokens_per_rank.is_contiguous());
TORCH_BIND_ASSERT(num_tokens_per_rank.size(0) == num_ranks);
auto num_tokens = static_cast<int>(new_x.size(0));
auto hidden = static_cast<int>(new_x.size(1));
auto num_experts = static_cast<int64_t>(num_tokens_per_expert.size(0));
auto num_local_experts = static_cast<int>(num_experts / num_ranks);
// Top-k checks
int num_topk = static_cast<int>(topk_idx.size(1));
TORCH_BIND_ASSERT(num_experts > 0);
TORCH_BIND_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous());
TORCH_BIND_ASSERT(topk_weights.dim() == 2 and topk_weights.is_contiguous());
TORCH_BIND_ASSERT(num_tokens == topk_idx.size(0));
TORCH_BIND_ASSERT(num_topk == topk_weights.size(1));
TORCH_BIND_ASSERT(topk_weights.scalar_type() == at::kFloat);
int send_per_group = 3; // (send_to_expert_num, send_to_expert_offset, send_rank_tokens)
auto send_data = at::empty({num_experts * send_per_group}, at::dtype(at::kInt).device(x.device()));
int64_t send_count = send_per_group * num_local_experts * num_ranks;
auto send_data_offset = at::empty({num_experts}, at::dtype(at::kInt).device(x.device()));
at::Tensor recv_data = at::empty({num_experts * send_per_group}, at::dtype(at::kInt).device(x.device()));
int64_t local_rank_size = num_ranks;
int64_t local_rank_id = rank % local_rank_size;
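// NotifyDispatch: all-to-all exchange of the per-expert send counts so this rank
// learns how many tokens each peer will send and can size its receive buffers.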
EXEC_NPU_CMD(aclnnNotifyDispatch,
send_data,
num_tokens_per_expert,
send_count,
num_tokens,
group_ep_ptr, // commGroup
num_ranks, // rankSize
rank, // rankId
local_rank_size,
local_rank_id,
send_data_offset,
recv_data);
auto options_cpu = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
std::vector<int32_t> local_expert_acc(num_experts, 0);
auto send_token_idx_cpu = at::empty({num_tokens, num_topk}, options_cpu);
auto send_token_idx_ptr = send_token_idx_cpu.data_ptr<int>();
auto topk_idx_cpu = topk_idx.to(at::kCPU);
auto topk_idx_ptr = topk_idx_cpu.data_ptr<int64_t>();
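// For every (token, top-k slot), record the token's running position among all local
// tokens routed to the same expert; the dispatch kernel uses it to place the token
// inside the destination expert's buffer.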
for (int i = 0; i < num_tokens; ++i) {
for (int j = 0; j < num_topk; ++j) {
int64_t expert_idx = topk_idx_ptr[i * num_topk + j];
if (expert_idx >= 0) {
int32_t cnt = local_expert_acc[expert_idx];
send_token_idx_ptr[i * num_topk + j] = cnt;
local_expert_acc[expert_idx]++;
}
}
}
TORCH_BIND_ASSERT(recv_data.dim() == 1 and recv_data.is_contiguous());
TORCH_BIND_ASSERT(recv_data.size(0) % num_experts == 0);
at::Tensor recv_offset_cpu = at::empty({num_experts}, options_cpu);
at::Tensor recv_count_cpu = at::empty({num_experts}, options_cpu);
auto recv_data_cpu = recv_data.to(at::kCPU);
auto recv_data_ptr = recv_data_cpu.data_ptr<int>();
auto recv_count_ptr = recv_count_cpu.data_ptr<int>();
auto recv_offset_ptr = recv_offset_cpu.data_ptr<int>();
int64_t total_recv_tokens = 0;
int64_t num_max_dispatch_tokens_per_rank = 0;
std::vector<int64_t> num_recv_tokens_per_expert_list;
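// recv_data packs three ints per (src_rank, local_expert) pair: token count, offset in
// the sender's window, and the sender's total token count. Fold them into cumulative
// receive counts, per-pair offsets, and the worst-case tokens sent by any single rank.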
for (int64_t local_e = 0; local_e < num_local_experts; ++local_e) {
int64_t local_expert_recv_tokens = 0;
for (int64_t src_rank = 0; src_rank < num_ranks; ++src_rank) {
int64_t index = local_e * num_ranks + src_rank;
int64_t pair_idx = send_per_group * (src_rank * num_local_experts + local_e);
int recv_cnt = recv_data_ptr[pair_idx]; // token count from this src_rank for this global expert
int recv_off = recv_data_ptr[pair_idx + 1]; // offset within that src_rank's send window
int64_t send_num_tokens = recv_data_ptr[pair_idx + 2]; // total tokens (batch size) sent by that src_rank
total_recv_tokens += recv_cnt;
recv_count_ptr[index] = total_recv_tokens;
recv_offset_ptr[index] = recv_off;
num_max_dispatch_tokens_per_rank = std::max(num_max_dispatch_tokens_per_rank, send_num_tokens);
local_expert_recv_tokens += recv_cnt;
}
num_recv_tokens_per_expert_list.push_back(local_expert_recv_tokens);
}
auto option = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU);
at::Tensor num_recv_tokens_per_expert = torch::from_blob(
num_recv_tokens_per_expert_list.data(), {static_cast<int64_t>(num_recv_tokens_per_expert_list.size())}, option)
.clone();
at::Tensor expert_ids = topk_idx.to(at::kInt);
int64_t tp_size = 1;
int64_t tp_rank = 0;
int64_t quant_mode = 0;
int64_t global_bs = static_cast<int64_t>(
std::max(num_max_dispatch_tokens_per_rank * num_ranks, static_cast<int64_t>(num_worst_tokens)));
auto send_token_idx = send_token_idx_cpu.to(x.device());
auto recv_offset = recv_offset_cpu.to(x.device());
auto recv_count = recv_count_cpu.to(x.device());
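// Keep at least one row so the output tensors are never zero-sized when no tokens are received.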
int total_cnt = total_recv_tokens;
if (total_cnt == 0) {
total_cnt = 1;
}
auto expandx_out = at::empty({total_cnt, hidden}, x.options());
auto dynamic_scales_out = at::empty({total_cnt}, at::dtype(at::kFloat).device(x.device()));
auto expand_idx_out = at::empty({total_cnt * 3}, at::dtype(at::kInt).device(x.device()));
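// MoeDispatchNormal: perform the actual token shuffle into expandx_out using the
// offsets and counts computed above.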
EXEC_NPU_CMD(aclnnMoeDispatchNormal,
new_x,
expert_ids,
send_data_offset,
send_token_idx,
recv_offset,
recv_count,
group_ep_ptr, // commGroup
num_ranks, // rankSize
rank, // rankId
group_ep_ptr,
tp_size,
tp_rank,
num_experts,
quant_mode,
global_bs,
expandx_out,
dynamic_scales_out,
expand_idx_out);
// Return values
return {expandx_out, expand_idx_out, recv_count, num_recv_tokens_per_expert};
}
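// combine_prefill: return the dispatched expert outputs to their source ranks and
// weighted-sum them per token (aclnnMoeCombineNormal); the inverse of dispatch_prefill.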
at::Tensor combine_prefill(const at::Tensor& x, const at::Tensor& topk_idx, const at::Tensor& topk_weights,
const at::Tensor& src_idx, const at::Tensor& send_head, c10::string_view groupEp,
int64_t rank, int64_t num_ranks) {
std::vector<char> group_ep_chrs(groupEp.begin(), groupEp.end());
group_ep_chrs.push_back('\0');
char* group_ep_ptr = &group_ep_chrs[0];
TORCH_BIND_ASSERT(x.dim() == 2 and x.is_contiguous());
at::Tensor recv_x = x;
at::Tensor topk_idx_p = topk_idx;
auto topk_idx_int32 = topk_idx_p.to(at::kInt);
at::Tensor expand_ids = topk_idx_int32;
at::Tensor token_src_info = src_idx;
at::Tensor ep_send_counts = send_head;
auto device = x.device();
const int num_tokens = topk_idx_p.size(0);
const int num_topk = topk_idx_p.size(1);
int64_t hidden = recv_x.size(1);
at::Tensor tp_send_counts = at::empty({1}, at::dtype(at::kInt).device(device));
int64_t tp_world_size = 1;
int64_t tp_rankId = 0;
int64_t moe_expert_number = send_head.size(0);
int64_t global_bs = topk_idx_p.size(0) * num_ranks;
// Weighted combine of the dispatched expert outputs back to their source tokens
auto combined_x = torch::empty({topk_weights.size(0), hidden}, x.options());
EXEC_NPU_CMD(aclnnMoeCombineNormal,
recv_x,
token_src_info,
ep_send_counts,
topk_weights,
tp_send_counts,
group_ep_ptr,
num_ranks,
rank,
group_ep_ptr,
tp_world_size,
tp_rankId,
moe_expert_number,
global_bs,
combined_x);
return combined_x;
}
} // namespace vllm_ascend
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
@@ -955,4 +1196,25 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
ops.def("matmul_allreduce_add_rmsnorm(Tensor x1, Tensor x2, Tensor residual, Tensor gamma, \
str groupTp, int tpRankSize, int tpRankId, float epsilon, bool isTransB, bool isGatherAddOut) -> (Tensor output, Tensor add_out)");
ops.impl("matmul_allreduce_add_rmsnorm", torch::kPrivateUse1, &vllm_ascend::matmul_allreduce_add_rmsnorm);
ops.def("get_dispatch_layout(Tensor topk_idx, int num_experts, int "
"num_ranks) -> (Tensor num_tokens_per_rank, Tensor "
"num_tokens_per_expert, Tensor is_token_in_rank_bool)");
ops.impl("get_dispatch_layout", torch::kPrivateUse1,
&vllm_ascend::get_dispatch_layout);
ops.def(
"dispatch_prefill(Tensor x, Tensor topk_idx, Tensor topk_weights, "
"Tensor num_tokens_per_rank, Tensor is_token_in_rank, Tensor "
"num_tokens_per_expert, int num_worst_tokens, str groupEp, int rank, "
"int num_ranks) -> (Tensor expandx_out, Tensor expand_idx_out, Tensor "
"recv_count, Tensor num_recv_tokens_per_expert)");
ops.impl("dispatch_prefill", torch::kPrivateUse1,
&vllm_ascend::dispatch_prefill);
ops.def("combine_prefill(Tensor x, Tensor topk_idx, Tensor topk_weights, "
"Tensor src_idx, Tensor send_head, str grouEp, int rank, int "
"num_ranks) -> Tensor");
ops.impl("combine_prefill", torch::kPrivateUse1,
&vllm_ascend::combine_prefill);
}