[Kernel] add custom op MatmulAllreduceAddRmsnorm (#4606)

What this PR does / why we need it?
Optimization of the fused operator for Qwen3 32B: Matmul, AllReduce,
Add, and RMSNorm

Does this PR introduce _any_ user-facing change?
No

How was this patch tested?

vLLM version: v0.11.2
vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

Signed-off-by: tongrunze <t00574058@china.huawei.com>
Co-authored-by: tongrunze <t00574058@china.huawei.com>
This commit is contained in:
Trunrain
2025-12-10 09:05:33 +08:00
committed by GitHub
parent f404c9af7f
commit ba9cda9dfd
16 changed files with 2854 additions and 1 deletions

View File

@@ -807,6 +807,36 @@ at::Tensor npu_sparse_flash_attention(
output);
return output;
}
std::tuple<at::Tensor, at::Tensor> matmul_allreduce_add_rmsnorm(
const at::Tensor &x1,
const at::Tensor &x2,
const at::Tensor &residual,
const at::Tensor &gamma,
c10::string_view group_tp,
int64_t tp_rank_size,
int64_t tp_rank_id,
double epsilon,
bool is_trans_b,
bool is_gather_add_out)
{
at::Tensor output = at::empty_like(residual);
at::Tensor add_out = at::empty_like(residual);
std::string group_tp_str(group_tp);
char *group_tp_ptr = group_tp_str.data();
float epsilon_f = static_cast<float>(epsilon);
EXEC_NPU_CMD(aclnnMatmulAllreduceAddRmsnorm,
// input
x1, x2, residual, gamma,
// attr
group_tp_ptr, tp_rank_size, tp_rank_id, epsilon_f, is_trans_b, is_gather_add_out,
// output
output, add_out);
return {output, add_out};
}
} // namespace vllm_ascend
@@ -921,4 +951,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
" int max_output_size, Tensor! out) -> Tensor"
);
ops.impl("dispatch_ffn_combine", torch::kPrivateUse1, &vllm_ascend::dispatch_ffn_combine);
ops.def("matmul_allreduce_add_rmsnorm(Tensor x1, Tensor x2, Tensor residual, Tensor gamma, \
str groupTp, int tpRankSize, int tpRankId, float epsilon, bool isTransB, bool isGatherAddOut) -> (Tensor output, Tensor add_out)");
ops.impl("matmul_allreduce_add_rmsnorm", torch::kPrivateUse1, &vllm_ascend::matmul_allreduce_add_rmsnorm);
}