[Kernel] add custom op MatmulAllreduceAddRmsnorm (#4606)
What this PR does / why we need it?
Optimizes Qwen3-32B by fusing Matmul, AllReduce, Add, and RMSNorm into a single custom operator.

Does this PR introduce any user-facing change?
No.

How was this patch tested?
vLLM version: v0.11.2
vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

Signed-off-by: tongrunze <t00574058@china.huawei.com>
Co-authored-by: tongrunze <t00574058@china.huawei.com>
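For context, the fusion replaces the unfused sequence Matmul → AllReduce → Add → RMSNorm that would otherwise run as four separate dispatches. Below is a minimal ATen sketch of the intended per-rank semantics; `fused_reference` is a hypothetical helper written for illustration, not code from this PR, the AllReduce step is elided, and the is_trans_b / is_gather_add_out attributes are ignored:

#include <ATen/ATen.h>
#include <tuple>

// Hypothetical unfused reference for what the fused kernel computes
// (sketch only; the real op runs this as one NPU kernel).
std::tuple<at::Tensor, at::Tensor> fused_reference(
    const at::Tensor &x1, const at::Tensor &x2,
    const at::Tensor &residual, const at::Tensor &gamma, double epsilon)
{
    at::Tensor y = at::matmul(x1, x2);   // per-rank partial matmul
    // In the real op, partial results are summed across the tensor-parallel
    // group named by group_tp (AllReduce); elided in this sketch.
    at::Tensor add_out = y + residual;   // residual add
    // RMSNorm over the last dim: x * rsqrt(mean(x^2) + eps) * gamma
    at::Tensor var = add_out.pow(2).mean({-1}, /*keepdim=*/true);
    at::Tensor output = add_out * at::rsqrt(var + epsilon) * gamma;
    return {output, add_out};
}

Note that, like the real op's schema, the sketch returns both the normalized output and the pre-norm add_out, which downstream layers can reuse as the next residual.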
@@ -807,6 +807,36 @@ at::Tensor npu_sparse_flash_attention(
        output);
    return output;
}

std::tuple<at::Tensor, at::Tensor> matmul_allreduce_add_rmsnorm(
    const at::Tensor &x1,
    const at::Tensor &x2,
    const at::Tensor &residual,
    const at::Tensor &gamma,
    c10::string_view group_tp,
    int64_t tp_rank_size,
    int64_t tp_rank_id,
    double epsilon,
    bool is_trans_b,
    bool is_gather_add_out)
{
    at::Tensor output = at::empty_like(residual);
    at::Tensor add_out = at::empty_like(residual);

    std::string group_tp_str(group_tp);
    char *group_tp_ptr = group_tp_str.data();

    float epsilon_f = static_cast<float>(epsilon);
    EXEC_NPU_CMD(aclnnMatmulAllreduceAddRmsnorm,
                 // input
                 x1, x2, residual, gamma,
                 // attr
                 group_tp_ptr, tp_rank_size, tp_rank_id, epsilon_f, is_trans_b, is_gather_add_out,
                 // output
                 output, add_out);

    return {output, add_out};
}

} // namespace vllm_ascend
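One detail in the implementation above is worth a note: `c10::string_view` does not own its buffer and need not be NUL-terminated, while the kernel launcher expects a C-style string. Copying into a `std::string` first guarantees termination and keeps the buffer alive across the call. A minimal self-contained sketch of just that pattern (names mirror the code above):

#include <c10/util/string_view.h>
#include <string>

void group_name_example(c10::string_view group_tp) {
    // Owning, NUL-terminated copy; valid for the lifetime of this scope.
    std::string group_tp_str(group_tp);
    // Non-const data() (C++17) yields the char* the launcher signature wants.
    char *group_tp_ptr = group_tp_str.data();
    (void)group_tp_ptr;  // in the real code this is passed to EXEC_NPU_CMD
}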
@@ -921,4 +951,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
        " int max_output_size, Tensor! out) -> Tensor"
    );
    ops.impl("dispatch_ffn_combine", torch::kPrivateUse1, &vllm_ascend::dispatch_ffn_combine);

    ops.def("matmul_allreduce_add_rmsnorm(Tensor x1, Tensor x2, Tensor residual, Tensor gamma, \
        str groupTp, int tpRankSize, int tpRankId, float epsilon, bool isTransB, bool isGatherAddOut) -> (Tensor output, Tensor add_out)");
    ops.impl("matmul_allreduce_add_rmsnorm", torch::kPrivateUse1, &vllm_ascend::matmul_allreduce_add_rmsnorm);
}
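Once registered, the op lives in the `_C_ascend` dispatcher namespace (from `CONCAT(_C, _ascend)`), so Python code can reach it as `torch.ops._C_ascend.matmul_allreduce_add_rmsnorm`. Below is a hedged C++ sketch of a call through the typed dispatcher API; the group name and tensor-parallel values are placeholders, not values from this PR:

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <tuple>

// Sketch: invoke the registered op via the PyTorch dispatcher. Schema types
// map as str -> c10::string_view, int -> int64_t, float -> double.
std::tuple<at::Tensor, at::Tensor> call_matmul_allreduce_add_rmsnorm(
    const at::Tensor &x1, const at::Tensor &x2,
    const at::Tensor &residual, const at::Tensor &gamma)
{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("_C_ascend::matmul_allreduce_add_rmsnorm", "")
        .typed<std::tuple<at::Tensor, at::Tensor>(
            const at::Tensor &, const at::Tensor &, const at::Tensor &,
            const at::Tensor &, c10::string_view, int64_t, int64_t,
            double, bool, bool)>();
    return op.call(x1, x2, residual, gamma,
                   /*group_tp=*/"tp_group",  // placeholder group name
                   /*tp_rank_size=*/8,       // placeholder world size
                   /*tp_rank_id=*/0,         // placeholder rank
                   /*epsilon=*/1e-6,
                   /*is_trans_b=*/false,
                   /*is_gather_add_out=*/false);
}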