diff --git a/vllm_ascend/torchair/ops/torchair_layernorm.py b/vllm_ascend/torchair/ops/torchair_layernorm.py
index d90f889..2f52dab 100644
--- a/vllm_ascend/torchair/ops/torchair_layernorm.py
+++ b/vllm_ascend/torchair/ops/torchair_layernorm.py
@@ -18,6 +18,32 @@ from typing import Optional, Tuple, Union
 
 import torch
 
+from vllm.config import get_current_vllm_config
+from vllm.model_executor.layers.layernorm import RMSNorm
+
+from vllm_ascend.utils import version_check
+
+_original_re_init = RMSNorm.__init__
+
+
+def torchair_rmsnorm_init_(
+    self,
+    hidden_size: int,
+    eps: float = 1e-6,
+    var_hidden_size: Optional[int] = None,
+    has_weight: bool = True,
+    dtype: Optional[torch.dtype] = None,
+) -> None:
+    _original_re_init(self, hidden_size, eps, var_hidden_size, has_weight,
+                      dtype)
+    vllm_config = get_current_vllm_config()
+    self.bias = None
+    self.torch_npu_check = version_check()
+    # quantization with anti_method m4 will generate a non-zero norm bias
+    if self.torch_npu_check and vllm_config.quant_config is not None and \
+        any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
+        self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
+                                       requires_grad=False)
 
 
 def torchair_rmsnorm_forward_oot(
@@ -33,6 +59,7 @@ def torchair_rmsnorm_forward_oot(
     """
     import torch_npu
 
+    torch_npu_check = version_check()
     from vllm_ascend.utils import is_310p
 
     if residual is not None:
@@ -45,7 +72,11 @@
         else:
             x, _, residual = torch_npu.npu_add_rms_norm(
                 x, residual, self.weight, self.variance_epsilon)
+        if torch_npu_check and self.bias is not None:
+            x.add_(self.bias)
         return x, residual
 
     x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
+    if torch_npu_check and self.bias is not None:
+        x.add_(self.bias)
     return x
diff --git a/vllm_ascend/torchair/utils.py b/vllm_ascend/torchair/utils.py
index 211d738..af61c65 100644
--- a/vllm_ascend/torchair/utils.py
+++ b/vllm_ascend/torchair/utils.py
@@ -229,10 +229,12 @@ def torchair_ops_patch():
     AscendDeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func  # type: ignore[method-assign]
     AscendDeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward  # type: ignore[method-assign]
 
+    AscendRMSNorm.__init__ = torchair_layernorm.torchair_rmsnorm_init_  # type: ignore[method-assign]
     AscendRMSNorm.forward_oot = torchair_layernorm.torchair_rmsnorm_forward_oot  # type: ignore[method-assign]
+
     AscendSiluAndMul.forward_oot = torchair_activation.torchair_silu_and_mul_forward_oot  # type: ignore[method-assign]
     AscendVocabParallelEmbedding.forward = vocab_embedding_forward  # type: ignore[method-assign]
 
 
 def super_kernel(prefix: str, option: str, enabled: bool = True):
-    return _super_kernel(prefix, option) if enabled else nullcontext()
\ No newline at end of file
+    return _super_kernel(prefix, option) if enabled else nullcontext()
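
For context, the semantics the patched `forward_oot` is intended to produce can be written out in plain PyTorch. The sketch below is illustrative only: it replaces the fused `torch_npu.npu_rms_norm` / `torch_npu.npu_add_rms_norm` kernels with unfused tensor ops, and the helper name `rms_norm_with_optional_bias` is hypothetical, not part of this change.

```python
# Minimal reference sketch of the patched RMSNorm semantics (not NPU-fused).
from typing import Optional, Tuple, Union

import torch


def rms_norm_with_optional_bias(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    residual: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # npu_add_rms_norm fuses this residual add with the normalization.
    if residual is not None:
        x = x + residual
        residual = x
    # RMSNorm: scale by the reciprocal root-mean-square over the hidden dim.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps) * weight
    # The anti_method m4 quantization scheme folds a non-zero bias into the
    # norm, so the patched forward adds it after the fused kernel returns.
    if bias is not None:
        x = x + bias
    return x if residual is None else (x, residual)
```

In the patch itself this bias add is gated on both `version_check()` and the presence of a `norm.bias` entry in `quant_config.quant_description`, so models quantized without the m4 anti-outlier method keep the original bias-free path.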