moe_gating_top_k (#5271)

1. What this PR does / why we need it?
This PR adds support for the moe_gating_top_k operator, which enables
renormalization (renorm) applied after the softmax step of MoE gating (a reference sketch is included after this list).
2. Does this PR introduce any user-facing change?
No user-facing changes are required.
3. How was this patch tested?
This patch was tested with the test_npu_moe_gating_top_k test case.
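For reference, the gating semantics behind this operator can be sketched with plain libtorch calls, assuming renorm means rescaling the selected top-k weights to sum to 1 after the softmax; the real operator is a fused NPU kernel (aclnnMoeGatingTopK) with additional grouping, bias, and scaling options, and the helper name gating_top_k_reference below is made up for illustration only:

#include <ATen/ATen.h>
#include <tuple>

// Reference-only sketch: softmax over router logits, top-k selection,
// then renormalization of the selected gating weights.
std::tuple<at::Tensor, at::Tensor> gating_top_k_reference(const at::Tensor& x, int64_t k) {
    at::Tensor probs = at::softmax(x.to(at::kFloat), /*dim=*/-1);     // per-token expert probabilities
    auto [topk_vals, topk_idx] = probs.topk(k, /*dim=*/-1);           // pick the k largest per row
    at::Tensor y = topk_vals / topk_vals.sum(-1, /*keepdim=*/true);   // renorm: each row sums to 1
    return std::make_tuple(y, topk_idx.to(at::kInt));
}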
vLLM version: release/v0.13.0
vLLM main: ad32e3e19c

---------

Signed-off-by: ZCG12345 <2097562023@qq.com>
Signed-off-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
Co-authored-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
Author: ZCG12345
Date: 2025-12-30 09:28:01 +08:00 (committed by GitHub)
Parent: 15d73f248e
Commit: 45c3c279e2
34 changed files with 4791 additions and 22 deletions


@@ -1118,6 +1118,60 @@ at::Tensor combine_prefill(const at::Tensor& x, const at::Tensor& topk_idx, cons
return combined_x;
}
std::tuple<at::Tensor, at::Tensor, at::Tensor> moe_gating_top_k(
const at::Tensor& x,
int64_t k,
int64_t kGroup,
int64_t groupCount,
int64_t groupSelectMode,
int64_t renorm,
int64_t normType,
bool outFlag,
double routedScalingFactor,
double eps,
const c10::optional<at::Tensor>& biasOptional
)
{
TORCH_CHECK(x.dim() == 2, "The x should be 2D");
TORCH_CHECK(
x.scalar_type() == at::kHalf || x.scalar_type() == at::kFloat || x.scalar_type() == at::kBFloat16,
"float16、float32 or bfloat16 tensor expected but got a tensor with dtype: ",
x.scalar_type());
auto x_size = x.sizes();
auto rows = x_size[0];
auto expert_num = x_size[1];
const at::Tensor &bias = c10::value_or_else(biasOptional, [] { return at::Tensor(); });
if (bias.defined()) {
TORCH_CHECK(x.scalar_type() == bias.scalar_type(), "The dtype of x and bias should be the same");
TORCH_CHECK(bias.dim() == 1, "The bias should be 1D");
auto bias_size = bias.sizes();
TORCH_CHECK(bias_size[0] == expert_num, "The bias first dim should be the same as the x second dim");
}
at::Tensor yOut = at::empty({rows, k}, x.options());
at::Tensor expertIdxOut = at::empty({rows, k}, x.options().dtype(at::kInt));
at::Tensor outOut = at::empty({rows, expert_num}, x.options().dtype(at::kFloat));
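// Launch the aclnnMoeGatingTopK kernel via EXEC_NPU_CMD; the output tensors allocated above are filled in place.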
EXEC_NPU_CMD(aclnnMoeGatingTopK,
x, // input_x
biasOptional,
k, // k
kGroup, // k_group
groupCount, // group_count
groupSelectMode, // group_select_mode
renorm, // renorm
normType, // norm_type
outFlag, // out_flag
routedScalingFactor, // routed_scaling_factor
eps, // eps
yOut, // y output: top-k gating weights
expertIdxOut, // expert_idx output: top-k expert indices
outOut // out output: per-expert scores (see out_flag)
);
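// Return order matches the registered schema below: (outOut, expertIdxOut, yOut).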
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(outOut, expertIdxOut, yOut);
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_moe_init_routing_custom(
const at::Tensor &x, const at::Tensor &expert_idx,
const c10::optional<at::Tensor> &scale, const c10::optional<at::Tensor> &offset, int64_t active_num,
@@ -1221,8 +1275,25 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_moe_init_routing_
} // namespace vllm_ascend
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
{
// vLLM-Ascend custom ops
ops.def(
"moe_gating_top_k(Tensor x, "
"int k, "
"int kGroup, "
"int groupCount, "
"int groupSelectMode, "
"int renorm, "
"int normType, "
"bool outFlag, "
"float routedScalingFactor, "
"float eps,"
"Tensor? biasOptional=None)"
"-> (Tensor outOut,Tensor expertIdxOut, Tensor yOut)"
);
ops.impl("moe_gating_top_k", torch::kPrivateUse1,&vllm_ascend::moe_gating_top_k);
// MoE gating
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);