From 7669963c272b21a4d0c034fc5cbdac22bf815a77 Mon Sep 17 00:00:00 2001
From: rjg-lyh <83491835+rjg-lyh@users.noreply.github.com>
Date: Tue, 17 Mar 2026 16:53:28 +0800
Subject: [PATCH] [Perf] Optimize bias handling in AscendRMSNorm (#7226)

### What this PR does / why we need it?

This PR optimizes bias handling in `AscendRMSNorm` without changing the
intended functional behavior.

In the current implementation, `AscendRMSNorm` creates a bias parameter based
on configuration-level detection (a `norm.bias` entry in the quantization
config's `quant_description`), even though some norm layers never actually
load a bias weight. Those layers still enter the bias branch at inference time
and execute an unnecessary `add_` operator.

To address this, the PR attaches a weight loader to the bias parameter and
records in a `bias_loaded` flag whether a bias weight was actually loaded.
The bias addition is then executed only when a bias is truly present.

This optimization removes redundant computation during inference and aligns
the bias application logic with the actual model weights.

- vLLM version: v0.17.0
- vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d

Signed-off-by: rjg-lyh <1318825571@qq.com>
---
 vllm_ascend/ops/layernorm.py | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py
index 11fb7538..d998ddab 100644
--- a/vllm_ascend/ops/layernorm.py
+++ b/vllm_ascend/ops/layernorm.py
@@ -37,11 +37,28 @@ class AscendRMSNorm(RMSNorm):
         super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
         vllm_config = get_current_vllm_config()
         self.bias = None
+        self.bias_loaded = False
+        # quantization with anti_method m4 will generate a non-zero norm bias
         if vllm_config.quant_config is not None and any(
                 "norm.bias" in name
                 for name in vllm_config.quant_config.quant_description):
             self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
                                            requires_grad=False)
+            self.bias.weight_loader = self._bias_weight_loader
+
+    def _bias_weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None:
+        if param.numel() == 1 and loaded_weight.numel() == 1:
+            # Sometimes scalar values aren't considered tensors with shapes,
+            # so if both param and loaded_weight are scalars,
+            # "broadcast" instead of copy.
+            param.data.fill_(loaded_weight.item())
+        else:
+            assert param.size() == loaded_weight.size(), (
+                f"Attempted to load weight ({loaded_weight.size()}) into parameter ({param.size()})"
+            )
+
+            param.data.copy_(loaded_weight)
+        self.bias_loaded = True
 
     def forward_oot(
         self,
@@ -62,7 +79,7 @@ class AscendRMSNorm(RMSNorm):
             return x, residual
 
         x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
-        if self.bias is not None:
+        if self.bias_loaded:
             x.add_(self.bias)
 
         weight_prefetch_method = get_weight_prefetch_method()
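
The flag works because vLLM dispatches checkpoint weights through per-parameter
`weight_loader` hooks: when a parameter carries its own loader, the loading loop
calls it, so `bias_loaded` flips only for layers whose checkpoint actually
contains a bias. The snippet below is a minimal, self-contained sketch of that
pattern; `ToyNorm` and the `load_weights` helper are illustrative stand-ins for
this note, not the vllm_ascend code.

```python
import torch


class ToyNorm(torch.nn.Module):
    """Toy stand-in following the same convention as the patch: the bias
    Parameter carries a per-parameter `weight_loader`, and a flag records
    whether a bias weight was actually seen in the checkpoint."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
                                       requires_grad=False)
        self.bias_loaded = False
        self.bias.weight_loader = self._bias_weight_loader

    def _bias_weight_loader(self, param: torch.nn.Parameter,
                            loaded_weight: torch.Tensor) -> None:
        param.data.copy_(loaded_weight)
        self.bias_loaded = True


def load_weights(module: torch.nn.Module, checkpoint: dict) -> None:
    # Simplified loading loop: if a parameter defines its own `weight_loader`,
    # use it; otherwise fall back to a plain copy.
    params = dict(module.named_parameters())
    for name, loaded_weight in checkpoint.items():
        param = params[name]
        loader = getattr(param, "weight_loader", None)
        if loader is not None:
            loader(param, loaded_weight)
        else:
            param.data.copy_(loaded_weight)


norm = ToyNorm(4)
load_weights(norm, {})                       # no bias in checkpoint -> flag stays False
print(norm.bias_loaded)                      # False: forward would skip the add_
load_weights(norm, {"bias": torch.ones(4)})  # bias present -> loader flips the flag
print(norm.bias_loaded)                      # True: forward would apply the bias
```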