[Kernel] add custom op MatmulAllreduceAddRmsnorm (#4606)
What this PR does / why we need it? Optimization of the fused operator for Qwen3 32B: Matmul, AllReduce, Add, and RMSNorm Does this PR introduce _any_ user-facing change? No How was this patch tested? vLLM version: v0.11.2 vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 Signed-off-by: tongrunze <t00574058@china.huawei.com> Co-authored-by: tongrunze <t00574058@china.huawei.com>
This commit is contained in:
@@ -264,6 +264,23 @@ at::Tensor npu_sparse_flash_attention_meta(
|
||||
at::Tensor output = at::empty(query.sizes(), query.options().dtype(query.dtype()));
|
||||
return output;
|
||||
}
|
||||
std::tuple<at::Tensor, at::Tensor> matmul_allreduce_add_rmsnorm_meta(
|
||||
const at::Tensor &x1,
|
||||
const at::Tensor &x2,
|
||||
const at::Tensor &residual,
|
||||
const at::Tensor &gamma,
|
||||
c10::string_view group_tp,
|
||||
int64_t tp_rank_size,
|
||||
int64_t tp_rank_id,
|
||||
double epsilon,
|
||||
bool is_trans_b,
|
||||
bool is_gather_add_out)
|
||||
{
|
||||
at::Tensor output = at::empty_like(residual);
|
||||
at::Tensor add_out = at::empty_like(residual);
|
||||
|
||||
return {output, add_out};
|
||||
}
|
||||
|
||||
} // namespace meta
|
||||
} // namespace vllm_ascend
|
||||
@@ -296,5 +313,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
|
||||
ops.impl("npu_sparse_flash_attention", &vllm_ascend::meta::npu_sparse_flash_attention_meta);
|
||||
// MoE dispatch-ffn-combine
|
||||
ops.impl("dispatch_ffn_combine", &vllm_ascend::meta::dispatch_ffn_combine_meta);
|
||||
// matmul allreduce add rmsnorm
|
||||
ops.impl("matmul_allreduce_add_rmsnorm", &vllm_ascend::meta::matmul_allreduce_add_rmsnorm_meta);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user