[Kernel] Add moe_gating_top_k operator support for Ascend NPU (#5579)
### What this PR does / why we need it?
1.replace moe_gating_top_k from torch_npu with custom op
2.enable the renorm function of moe_gating_top_k in softmax scenerio
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
No need test
- vLLM version: v0.13.0
- vLLM main:
7157596103
---------
Signed-off-by: ZCG12345 <2097562023@qq.com>
This commit is contained in:
@@ -366,7 +366,43 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_moe_init_routing_
|
||||
at::Tensor expanded_scale = at::empty({expanded_scale_len}, x.options().dtype(at::kFloat));
|
||||
return {expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum, expanded_scale};
|
||||
}
|
||||
std::tuple<at::Tensor,at::Tensor, at::Tensor> moe_gating_top_k_meta(
|
||||
const at::Tensor& x,
|
||||
int64_t k,
|
||||
int64_t k_group,
|
||||
int64_t group_count,
|
||||
int64_t group_select_mode,
|
||||
int64_t renorm,
|
||||
int64_t norm_type,
|
||||
bool out_flag,
|
||||
double routed_scaling_factor,
|
||||
double eps,
|
||||
const c10::optional<at::Tensor>& bias_opt
|
||||
|
||||
)
|
||||
{
|
||||
TORCH_CHECK(x.dim() == 2, "The x should be 2D");
|
||||
TORCH_CHECK(
|
||||
x.scalar_type() == at::kHalf || x.scalar_type() == at::kFloat || x.scalar_type() == at::kBFloat16,
|
||||
"float16、float32 or bfloat16 tensor expected but got a tensor with dtype: ",
|
||||
x.scalar_type());
|
||||
|
||||
auto x_size = x.sizes();
|
||||
auto rows = x_size[0];
|
||||
auto expert_num = x_size[1];
|
||||
const at::Tensor &bias = c10::value_or_else(bias_opt, [] { return at::Tensor(); });
|
||||
if (bias.defined()) {
|
||||
TORCH_CHECK(x.scalar_type() == bias.scalar_type(), "The dtype of x and bias should be same");
|
||||
TORCH_CHECK(bias.dim() == 1, "The bias should be 1D");
|
||||
auto bias_size = bias.sizes();
|
||||
TORCH_CHECK(bias_size[0] == expert_num, "The bias first dim should be same as x second dim");
|
||||
}
|
||||
at::Tensor y = at::empty({rows, k}, x.options());
|
||||
at::Tensor expert_idx = at::empty({rows, k}, x.options().dtype(at::kInt));
|
||||
at::Tensor out = at::empty({rows, expert_num}, x.options().dtype(at::kFloat));
|
||||
|
||||
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(y,expert_idx,out);
|
||||
}
|
||||
} // namespace meta
|
||||
} // namespace vllm_ascend
|
||||
|
||||
@@ -374,6 +410,7 @@ namespace {
|
||||
// Register the meta implementations of the custom kernels for symbolic tracing, this will also
|
||||
// the custom kernel been captured into aclgraph
|
||||
TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
|
||||
|
||||
// Rotary embedding meta implementation
|
||||
ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
|
||||
// Masked input and mask meta implementation
|
||||
@@ -402,5 +439,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
|
||||
ops.impl("matmul_allreduce_add_rmsnorm", &vllm_ascend::meta::matmul_allreduce_add_rmsnorm_meta);
|
||||
// moe_init_routing_custom
|
||||
ops.impl("npu_moe_init_routing_custom", &vllm_ascend::meta::npu_moe_init_routing_custom_meta);
|
||||
// Moe_gating_top_k
|
||||
ops.impl("moe_gating_top_k", &vllm_ascend::meta::moe_gating_top_k_meta);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user