Add GemmaRmsNorm ACLGraph Support (#6473)

### What this PR does / why we need it?
1. New Custom NPU Operation: Introduced npu_gemma_rms_norm in
csrc/torch_binding.cpp to provide optimized Gemma RMS Normalization
support for Ascend NPUs. This function includes logic to handle dynamic
shapes for the gamma tensor.
2. PyTorch Operator Registration: The new npu_gemma_rms_norm operation
has been registered with the PyTorch custom operator library, making it
accessible from Python.
Meta-Implementation for ACLGraph: A corresponding meta-implementation,
npu_gemma_rms_norm_meta, was added in csrc/torch_binding_meta.cpp. This
is crucial for symbolic tracing and allowing the custom kernel to be
captured and optimized by ACLGraph.
3. Python Frontend Integration: The vllm_ascend/ops/layernorm.py file
was updated to utilize the newly added
torch.ops._C_ascend.npu_gemma_rms_norm for Gemma RMS Normalization,
replacing the generic torch_npu.npu_rms_norm
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

---------

Signed-off-by: SunnyLee219 <3294305115@qq.com>
Signed-off-by: LeeWenquan <83354342+SunnyLee151064@users.noreply.github.com>
This commit is contained in:
LeeWenquan
2026-03-05 16:15:07 +08:00
committed by GitHub
parent 5a3744c542
commit 3047b724b3
3 changed files with 66 additions and 3 deletions

View File

@@ -553,6 +553,34 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> dispatch_prefill(
return {expandx_out, expand_idx_out, recv_count, num_recv_tokens_per_expert};
}
std::tuple<at::Tensor, at::Tensor> npu_gemma_rms_norm(
const at::Tensor& x,
const at::Tensor& gamma,
double epsilon)
{
int64_t dim_x = x.dim();
int64_t dim_gamma = gamma.dim();
int64_t diff = dim_x - dim_gamma;
std::vector<int64_t> new_shape;
at::Tensor rstd;
if (diff > 0) {
new_shape.reserve(dim_x);
auto x_sizes = x.sizes();
for (int64_t i = 0; i < diff; ++i) {
new_shape.push_back(x_sizes[i]);
}
for (int64_t i = 0; i < dim_gamma; ++i) {
new_shape.push_back(1);
}
} else {
new_shape.assign(dim_x, 1);
}
rstd = at::empty(new_shape, x.options().dtype(at::kFloat));
at::Tensor y = at::empty(x.sizes(), x.options());
EXEC_NPU_CMD(aclnnGemmaRmsNorm, x, gamma, epsilon, y, rstd);
return std::tuple<at::Tensor, at::Tensor>(y, rstd);
}
void transpose_kv_cache_by_block(
const at::TensorList &kCache,
const at::TensorList &vCache,
@@ -575,6 +603,14 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
{
// vLLM-Ascend custom ops
// Gemma RmsNorm
ops.def(
"npu_gemma_rms_norm(Tensor x, "
"Tensor gamma, "
"float epsilon=1e-6)"
"-> (Tensor y ,Tensor rstd)"
);
ops.impl("npu_gemma_rms_norm", torch::kPrivateUse1, &vllm_ascend::npu_gemma_rms_norm);
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);