From c3fee66806f252476796389ea73d13a8aca60146 Mon Sep 17 00:00:00 2001
From: socrahow
Date: Sun, 28 Sep 2025 21:19:10 +0800
Subject: [PATCH] [Model] Optimizing gemma3 model's GemmaRMSNorm function (#3151)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What this PR does / why we need it?
Before this optimization, the GemmaRMSNorm time during one decode step was 531.5 us. After it, the time is 105 us.

I closed the previous PR (https://github.com/vllm-project/vllm-ascend/pull/2456) by mistake, so I am resubmitting it now.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
A hypothetical timing sketch showing one way the latency numbers above could be measured is appended after the diff.

- vLLM version: v0.10.2
- vLLM main: https://github.com/vllm-project/vllm/commit/b1068903fdca26cf6b4a1a51a32c3365ce3ac636

---------

Signed-off-by: socrahow
---
 vllm_ascend/ops/layernorm.py | 29 ++++++++++++++++++++++++++++-
 vllm_ascend/utils.py         |  4 +++-
 2 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py
index da48362..3dfca53 100644
--- a/vllm_ascend/ops/layernorm.py
+++ b/vllm_ascend/ops/layernorm.py
@@ -19,7 +19,7 @@ from typing import Optional, Tuple, Union, cast
 
 import torch
 from vllm.forward_context import get_forward_context
-from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
 
 
 def _addrmsnorm_forward_oot(
@@ -130,3 +130,30 @@ class AscendQuantRMSNorm(AscendRMSNorm):
             x, residual = super().forward_oot(x, residual)
             return x.add_(self.bias), residual
         return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias)
+
+
+class AscendGemmaRMSNorm(GemmaRMSNorm):
+
+    def forward_oot(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        import torch_npu
+
+        from vllm_ascend.utils import is_310p
+        if residual is not None:
+            if is_310p():
+                orig_dtype = residual.dtype
+                x = x + residual.to(x.dtype)
+                residual = x.to(orig_dtype)
+                x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
+                                              self.variance_epsilon)
+            else:
+                x, _, residual = torch_npu.npu_add_rms_norm(
+                    x, residual, 1.0 + self.weight, self.variance_epsilon)
+            return x, residual
+
+        x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
+                                      self.variance_epsilon)
+        return x
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 7b430ab..6107e25 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -505,7 +505,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
     from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
     from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE,
                                                   AscendSharedFusedMoE)
-    from vllm_ascend.ops.layernorm import AscendQuantRMSNorm, AscendRMSNorm
+    from vllm_ascend.ops.layernorm import (AscendGemmaRMSNorm,
+                                           AscendQuantRMSNorm, AscendRMSNorm)
     from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
                                         AscendMergedColumnParallelLinear,
                                         AscendQKVParallelLinear,
@@ -530,6 +531,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
         "ParallelLMHead": AscendParallelLMHead,
         "LogitsProcessor": AscendLogitsProcessor,
         "RMSNorm": AscendRMSNorm,
+        "GemmaRMSNorm": AscendGemmaRMSNorm,
         "FusedMoE": AscendFusedMoE,
         "SharedFusedMoE": AscendSharedFusedMoE,
         "MultiHeadLatentAttention": AscendMultiHeadLatentAttention,
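
For reference, below is a minimal sketch of how latency figures like the 531.5 us / 105 us above could be measured. It is not part of the patch: the shapes, `hidden_size`, and iteration counts are illustrative, `torch.npu.Event` is assumed to mirror `torch.cuda.Event` semantics, and an Ascend device with `torch_npu` installed is required. The signatures of `torch_npu.npu_rms_norm` and `torch_npu.npu_add_rms_norm` follow their usage in the diff above.

```python
# Hypothetical micro-benchmark (not part of the patch): times the fused
# torch_npu.npu_add_rms_norm kernel that AscendGemmaRMSNorm.forward_oot
# dispatches to on non-310P devices.
import torch
import torch_npu  # registers the NPU backend and custom ops

hidden_size = 2560  # illustrative; the real value is model-specific
x = torch.randn(8, hidden_size, dtype=torch.float16).npu()
residual = torch.randn_like(x)
# GemmaRMSNorm stores its weight as (gamma - 1), hence the 1.0 + weight below.
weight = torch.zeros(hidden_size, dtype=torch.float16).npu()
eps = 1e-6

def fused_add_rmsnorm():
    # One fused kernel launch replaces the eager residual add + normalize.
    out, _, new_residual = torch_npu.npu_add_rms_norm(
        x, residual, 1.0 + weight, eps)
    return out, new_residual

for _ in range(10):  # warm-up so kernel compilation is excluded
    fused_add_rmsnorm()

# Assumption: torch.npu.Event mirrors torch.cuda.Event semantics.
start = torch.npu.Event(enable_timing=True)
end = torch.npu.Event(enable_timing=True)
iters = 100
start.record()
for _ in range(iters):
    fused_add_rmsnorm()
end.record()
torch.npu.synchronize()
# elapsed_time returns milliseconds; convert to microseconds per call.
print(f"fused add+rmsnorm: {start.elapsed_time(end) / iters * 1000:.1f} us/iter")
```

The fused `npu_add_rms_norm` path collapses the residual add and the normalization into a single kernel launch, which is where most of the per-call saving comes from; the 310P fall-back in the patch keeps the same numerics using separate eager ops because no fused kernel is available there.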