[Kernel] add custom op MatmulAllreduceAddRmsnorm (#4606)

What this PR does / why we need it?
Adds the fused custom op `MatmulAllreduceAddRmsnorm`, which combines Matmul,
AllReduce, Add, and RMSNorm into a single kernel as an optimization for Qwen3 32B.
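
For orientation, the fused op should be numerically equivalent to the unfused sequence below (a minimal Python sketch under assumed semantics: the argument names come from the meta function in the diff, while the exact epsilon placement and the role of `is_gather_add_out` are assumptions, not taken from this PR):

```python
import torch
import torch.distributed as dist

def matmul_allreduce_add_rmsnorm_ref(x1, x2, residual, gamma,
                                     epsilon=1e-6, is_trans_b=False):
    # 1. Matmul, with the weight optionally transposed (is_trans_b).
    y = x1 @ (x2.t() if is_trans_b else x2)
    # 2. AllReduce the partial products across the tensor-parallel group
    #    (assumes torch.distributed is already initialized).
    dist.all_reduce(y)
    # 3. Residual add; the fused op also returns this intermediate.
    add_out = y + residual
    # 4. RMSNorm over the hidden dimension, scaled by gamma.
    var = add_out.float().pow(2).mean(-1, keepdim=True)
    output = (add_out.float() * torch.rsqrt(var + epsilon)
              * gamma.float()).to(add_out.dtype)
    return output, add_out
```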

Does this PR introduce _any_ user-facing change?
No

How was this patch tested?

vLLM version: v0.11.2
vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

Signed-off-by: tongrunze <t00574058@china.huawei.com>
Co-authored-by: tongrunze <t00574058@china.huawei.com>
@@ -264,6 +264,23 @@ at::Tensor npu_sparse_flash_attention_meta(
at::Tensor output = at::empty(query.sizes(), query.options().dtype(query.dtype()));
return output;
}
std::tuple<at::Tensor, at::Tensor> matmul_allreduce_add_rmsnorm_meta(
    const at::Tensor &x1,
    const at::Tensor &x2,
    const at::Tensor &residual,
    const at::Tensor &gamma,
    c10::string_view group_tp,
    int64_t tp_rank_size,
    int64_t tp_rank_id,
    double epsilon,
    bool is_trans_b,
    bool is_gather_add_out)
{
    at::Tensor output = at::empty_like(residual);
    at::Tensor add_out = at::empty_like(residual);
    return {output, add_out};
}
} // namespace meta
} // namespace vllm_ascend
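
The function above is a meta ("fake") kernel: it allocates nothing on the device and exists only so tracing machinery such as torch.compile and FakeTensor propagation can infer output shapes and dtypes; both outputs mirror `residual` via `at::empty_like`. A hedged sketch of exercising it (the `_C_ascend` namespace follows from `CONCAT(_C, _ascend)` in the registration hunk below; the tensor shapes and how the extension module gets loaded on a given build are assumptions):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Under FakeTensorMode the call dispatches to the Meta implementation,
# so no NPU kernel runs and no real device memory is touched.
with FakeTensorMode():
    x1 = torch.empty(4, 1024, dtype=torch.bfloat16)
    x2 = torch.empty(1024, 1024, dtype=torch.bfloat16)
    residual = torch.empty(4, 1024, dtype=torch.bfloat16)
    gamma = torch.empty(1024, dtype=torch.bfloat16)
    out, add_out = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(
        x1, x2, residual, gamma, "tp_group", 8, 0, 1e-6, False, False)
    # Shapes come from the meta function: both outputs match residual.
    assert out.shape == residual.shape and add_out.shape == residual.shape
```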
@@ -296,5 +313,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
ops.impl("npu_sparse_flash_attention", &vllm_ascend::meta::npu_sparse_flash_attention_meta);
// MoE dispatch-ffn-combine
ops.impl("dispatch_ffn_combine", &vllm_ascend::meta::dispatch_ffn_combine_meta);
// matmul allreduce add rmsnorm
ops.impl("matmul_allreduce_add_rmsnorm", &vllm_ascend::meta::matmul_allreduce_add_rmsnorm_meta);
}
}
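
With this Meta registration in place, graph capture can trace through `matmul_allreduce_add_rmsnorm` without launching the real kernel; the actual device implementation is presumably registered separately under the device (NPU) dispatch key, following the same pattern as the neighboring `npu_sparse_flash_attention` and `dispatch_ffn_combine` ops.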