Add GemmaRmsNorm ACLGraph Support (#6473)

### What this PR does / why we need it?
1. New Custom NPU Operation: Introduced npu_gemma_rms_norm in
csrc/torch_binding.cpp to provide optimized Gemma RMS Normalization
support for Ascend NPUs. This function includes logic to handle dynamic
shapes for the gamma tensor.
2. PyTorch Operator Registration: The new npu_gemma_rms_norm operation
has been registered with the PyTorch custom operator library, making it
accessible from Python.
Meta-Implementation for ACLGraph: A corresponding meta-implementation,
npu_gemma_rms_norm_meta, was added in csrc/torch_binding_meta.cpp. This
is crucial for symbolic tracing and allowing the custom kernel to be
captured and optimized by ACLGraph.
3. Python Frontend Integration: The vllm_ascend/ops/layernorm.py file
was updated to utilize the newly added
torch.ops._C_ascend.npu_gemma_rms_norm for Gemma RMS Normalization,
replacing the generic torch_npu.npu_rms_norm
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

---------

Signed-off-by: SunnyLee219 <3294305115@qq.com>
Signed-off-by: LeeWenquan <83354342+SunnyLee151064@users.noreply.github.com>
This commit is contained in:
LeeWenquan
2026-03-05 16:15:07 +08:00
committed by GitHub
parent 5a3744c542
commit 3047b724b3
3 changed files with 66 additions and 3 deletions

View File

@@ -418,6 +418,33 @@ std::tuple<at::Tensor,at::Tensor, at::Tensor> npu_add_rms_norm_bias_meta(
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(y, rstd, x);
}
std::tuple<at::Tensor, at::Tensor> npu_gemma_rms_norm_meta(
const at::Tensor& x,
const at::Tensor& gamma,
double epsilon)
{
int64_t dim_x = x.dim();
int64_t dim_gamma = gamma.dim();
int64_t diff = dim_x - dim_gamma;
c10::SymDimVector new_shape;
at::Tensor rstd;
if (diff > 0) {
new_shape.reserve(dim_x);
auto x_sizes = x.sym_sizes();
for (int64_t i = 0; i < diff; ++i) {
new_shape.push_back(x_sizes[i]);
}
for (int64_t i = 0; i < dim_gamma; ++i) {
new_shape.push_back(c10::SymInt(1));
}
} else {
new_shape.assign(dim_x, c10::SymInt(1));
}
rstd = at::empty_symint(new_shape, x.options().dtype(at::kFloat));
at::Tensor y = at::empty_symint(x.sym_sizes(), x.options());
return std::tuple<at::Tensor, at::Tensor>(y, rstd);
}
void transpose_kv_cache_by_block_meta(
const at::TensorList &k_cache,
const at::TensorList &v_cache,
@@ -430,7 +457,6 @@ void transpose_kv_cache_by_block_meta(
{
return;
}
} // namespace meta
} // namespace vllm_ascend
@@ -438,7 +464,8 @@ namespace {
// Register the meta implementations of the custom kernels for symbolic tracing, this will also
// the custom kernel been captured into aclgraph
TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
//Gemma rmsnorm meta implementation
ops.impl("npu_gemma_rms_norm", &vllm_ascend::meta::npu_gemma_rms_norm_meta);
// Masked input and mask meta implementation
ops.impl("get_masked_input_and_mask", &vllm_ascend::meta::get_masked_input_and_mask_meta);
// Bgmv expand