From 3047b724b39127fdb60f6519c31883908899d2bb Mon Sep 17 00:00:00 2001
From: LeeWenquan <83354342+SunnyLee151064@users.noreply.github.com>
Date: Thu, 5 Mar 2026 16:15:07 +0800
Subject: [PATCH] Add GemmaRmsNorm ACLGraph Support (#6473)

### What this PR does / why we need it?

1. New Custom NPU Operation: Introduced npu_gemma_rms_norm in csrc/torch_binding.cpp to provide optimized Gemma RMS Normalization support for Ascend NPUs. The function derives the shape of the rstd output from the rank difference between x and gamma, so it handles a gamma tensor with fewer dimensions than the input.
2. PyTorch Operator Registration: The new npu_gemma_rms_norm operation is registered with the PyTorch custom operator library, making it accessible from Python as torch.ops._C_ascend.npu_gemma_rms_norm.
3. Meta-Implementation for ACLGraph: A corresponding meta-implementation, npu_gemma_rms_norm_meta, was added in csrc/torch_binding_meta.cpp. This is required for symbolic tracing and allows the custom kernel to be captured and optimized by ACLGraph.
4. Python Frontend Integration: vllm_ascend/ops/layernorm.py now calls torch.ops._C_ascend.npu_gemma_rms_norm for Gemma RMS Normalization, replacing the generic torch_npu.npu_rms_norm. Note that the frontend passes self.weight directly rather than 1.0 + self.weight, since the Gemma kernel folds the 1 + gamma scaling in itself.
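For illustration, a minimal Python sketch of how the new op is invoked (this snippet is not part of the patch: the shapes and dtypes are arbitrary examples, and it assumes an Ascend environment where torch_npu and the compiled vllm_ascend extension are loaded so that torch.ops._C_ascend is populated):

```python
# Minimal usage sketch — not part of this patch. Assumes an Ascend NPU
# environment where torch_npu provides the "npu" device and the compiled
# vllm_ascend extension has been loaded (which registers torch.ops._C_ascend).
import torch
import torch_npu  # noqa: F401

x = torch.randn(4, 128, dtype=torch.float16, device="npu")   # example activations
gamma = torch.zeros(128, dtype=torch.float16, device="npu")  # Gemma weight; the kernel applies 1 + gamma

# epsilon defaults to 1e-6 in the registered schema, so it may be omitted.
y, rstd = torch.ops._C_ascend.npu_gemma_rms_norm(x, gamma, 1e-6)

assert y.shape == x.shape            # normalized output has x's shape
assert rstd.shape == (4, 1)          # x's leading dims, with 1s for gamma's dims
assert rstd.dtype == torch.float32   # rstd is allocated as float32
```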
### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

- vLLM version: v0.14.1
- vLLM main: https://github.com/vllm-project/vllm/commit/dc917cceb877dfd13f98c538c4c96158047d98bd

---------

Signed-off-by: SunnyLee219 <3294305115@qq.com>
Signed-off-by: LeeWenquan <83354342+SunnyLee151064@users.noreply.github.com>
---
 csrc/torch_binding.cpp       | 36 ++++++++++++++++++++++++++++++++++++
 csrc/torch_binding_meta.cpp  | 31 +++++++++++++++++++++++++++++--
 vllm_ascend/ops/layernorm.py |  2 +-
 3 files changed, 66 insertions(+), 3 deletions(-)

diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp
index 77eaf4d5..ae60ce05 100644
--- a/csrc/torch_binding.cpp
+++ b/csrc/torch_binding.cpp
@@ -553,6 +553,34 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> dispatch_prefill(
     return {expandx_out, expand_idx_out, recv_count, num_recv_tokens_per_expert};
 }
 
+std::tuple<at::Tensor, at::Tensor> npu_gemma_rms_norm(
+    const at::Tensor& x,
+    const at::Tensor& gamma,
+    double epsilon)
+{
+    int64_t dim_x = x.dim();
+    int64_t dim_gamma = gamma.dim();
+    int64_t diff = dim_x - dim_gamma;
+    std::vector<int64_t> new_shape;
+    at::Tensor rstd;
+    if (diff > 0) {
+        new_shape.reserve(dim_x);
+        auto x_sizes = x.sizes();
+        for (int64_t i = 0; i < diff; ++i) {
+            new_shape.push_back(x_sizes[i]);
+        }
+        for (int64_t i = 0; i < dim_gamma; ++i) {
+            new_shape.push_back(1);
+        }
+    } else {
+        new_shape.assign(dim_x, 1);
+    }
+    rstd = at::empty(new_shape, x.options().dtype(at::kFloat));
+    at::Tensor y = at::empty(x.sizes(), x.options());
+    EXEC_NPU_CMD(aclnnGemmaRmsNorm, x, gamma, epsilon, y, rstd);
+    return std::tuple(y, rstd);
+}
+
 void transpose_kv_cache_by_block(
     const at::TensorList &kCache,
     const at::TensorList &vCache,
@@ -575,6 +603,14 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) {
 
     // vLLM-Ascend custom ops
+    // Gemma RmsNorm
+    ops.def(
+        "npu_gemma_rms_norm(Tensor x, "
+        "Tensor gamma, "
+        "float epsilon=1e-6)"
+        " -> (Tensor y, Tensor rstd)"
+    );
+    ops.impl("npu_gemma_rms_norm", torch::kPrivateUse1, &vllm_ascend::npu_gemma_rms_norm);
     ops.def("weak_ref_tensor(Tensor input) -> Tensor");
     ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp
index 6e3f66ec..25c2cf85 100644
--- a/csrc/torch_binding_meta.cpp
+++ b/csrc/torch_binding_meta.cpp
@@ -418,6 +418,33 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> npu_add_rms_norm_bias_meta(
     return std::tuple(y, rstd, x);
 }
 
+std::tuple<at::Tensor, at::Tensor> npu_gemma_rms_norm_meta(
+    const at::Tensor& x,
+    const at::Tensor& gamma,
+    double epsilon)
+{
+    int64_t dim_x = x.dim();
+    int64_t dim_gamma = gamma.dim();
+    int64_t diff = dim_x - dim_gamma;
+    c10::SymDimVector new_shape;
+    at::Tensor rstd;
+    if (diff > 0) {
+        new_shape.reserve(dim_x);
+        auto x_sizes = x.sym_sizes();
+        for (int64_t i = 0; i < diff; ++i) {
+            new_shape.push_back(x_sizes[i]);
+        }
+        for (int64_t i = 0; i < dim_gamma; ++i) {
+            new_shape.push_back(c10::SymInt(1));
+        }
+    } else {
+        new_shape.assign(dim_x, c10::SymInt(1));
+    }
+    rstd = at::empty_symint(new_shape, x.options().dtype(at::kFloat));
+    at::Tensor y = at::empty_symint(x.sym_sizes(), x.options());
+    return std::tuple(y, rstd);
+}
+
 void transpose_kv_cache_by_block_meta(
     const at::TensorList &k_cache,
     const at::TensorList &v_cache,
@@ -430,7 +457,6 @@ void transpose_kv_cache_by_block_meta(
 {
     return;
 }
-
 } // namespace meta
 } // namespace vllm_ascend
 
@@ -438,7 +464,8 @@ namespace {
 // Register the meta implementations of the custom kernels for symbolic tracing, this will also
 // the custom kernel been captured into aclgraph
 TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
-
+    // Gemma RmsNorm meta implementation
+    ops.impl("npu_gemma_rms_norm", &vllm_ascend::meta::npu_gemma_rms_norm_meta);
     // Masked input and mask meta implementation
     ops.impl("get_masked_input_and_mask", &vllm_ascend::meta::get_masked_input_and_mask_meta);
     // Bgmv expand
diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py
index 17214afb..11fb7538 100644
--- a/vllm_ascend/ops/layernorm.py
+++ b/vllm_ascend/ops/layernorm.py
@@ -87,7 +87,7 @@ class AscendGemmaRMSNorm(GemmaRMSNorm):
             x, _, residual = torch_npu.npu_add_rms_norm(x, residual, 1.0 + self.weight,
                                                         self.variance_epsilon)
             return x, residual
-        x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight, self.variance_epsilon)
+        x, _ = torch.ops._C_ascend.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
         return x