[Model] Optimizing gemma3 model's GemmaRMSNorm function (#3151)
### What this PR does / why we need it?
Before optimizing,the rmsnorm time in one decoding is 531.5us. After
optimizing,the rmsnorm time in one decoding is 105us.
I closed the previous
PR(https://github.com/vllm-project/vllm-ascend/pull/2456) by mistake and
resubmitted it now
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- vLLM version: v0.10.2
- vLLM main:
b1068903fd
---------
Signed-off-by: socrahow <suzihao4@h-partners.com>
This commit is contained in:
@@ -19,7 +19,7 @@ from typing import Optional, Tuple, Union, cast
|
||||
|
||||
import torch
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
||||
|
||||
|
||||
def _addrmsnorm_forward_oot(
|
||||
@@ -130,3 +130,30 @@ class AscendQuantRMSNorm(AscendRMSNorm):
|
||||
x, residual = super().forward_oot(x, residual)
|
||||
return x.add_(self.bias), residual
|
||||
return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias)
|
||||
|
||||
|
||||
class AscendGemmaRMSNorm(GemmaRMSNorm):
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
if residual is not None:
|
||||
if is_310p():
|
||||
orig_dtype = residual.dtype
|
||||
x = x + residual.to(x.dtype)
|
||||
residual = x.to(orig_dtype)
|
||||
x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
|
||||
self.variance_epsilon)
|
||||
else:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm(
|
||||
x, residual, 1.0 + self.weight, self.variance_epsilon)
|
||||
return x, residual
|
||||
|
||||
x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
|
||||
self.variance_epsilon)
|
||||
return x
|
||||
|
||||
@@ -505,7 +505,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
|
||||
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
|
||||
from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE,
|
||||
AscendSharedFusedMoE)
|
||||
from vllm_ascend.ops.layernorm import AscendQuantRMSNorm, AscendRMSNorm
|
||||
from vllm_ascend.ops.layernorm import (AscendGemmaRMSNorm,
|
||||
AscendQuantRMSNorm, AscendRMSNorm)
|
||||
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
|
||||
AscendMergedColumnParallelLinear,
|
||||
AscendQKVParallelLinear,
|
||||
@@ -530,6 +531,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
|
||||
"ParallelLMHead": AscendParallelLMHead,
|
||||
"LogitsProcessor": AscendLogitsProcessor,
|
||||
"RMSNorm": AscendRMSNorm,
|
||||
"GemmaRMSNorm": AscendGemmaRMSNorm,
|
||||
"FusedMoE": AscendFusedMoE,
|
||||
"SharedFusedMoE": AscendSharedFusedMoE,
|
||||
"MultiHeadLatentAttention": AscendMultiHeadLatentAttention,
|
||||
|
||||
Reference in New Issue
Block a user