################################################################################
# Copyright (c) 2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import os
from typing import Optional, Tuple, Union

import torch
import torch_br
from fastcore.basics import patch_to
from torch import Tensor, nn
from vllm.model_executor.layers.layernorm import RMSNorm


@patch_to(RMSNorm)
def forward_oot(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Out-of-tree RMSNorm forward that dispatches to Biren fused kernels.

    Args:
        x: Input activations to normalize.
        residual: Optional residual tensor; when given, the fused
            add+RMSNorm kernel is used and both the normalized output
            and the residual sum are returned.

    Returns:
        The normalized tensor, or a ``(normalized, residual_sum)`` pair
        when ``residual`` is provided.
    """
    # The supa kernels require an fp32 weight; convert in place once.
    # After the first call the dtype is float32, so this branch is skipped.
    if self.weight.data.dtype == torch.bfloat16:
        self.weight.data = self.weight.data.to(torch.float32)

    if residual is not None:
        # Fused residual-add + RMSNorm inference kernel.
        y_supa, add_out_supa = torch_br.supa_add_rmsnorm_infer(  # type: ignore
            x, residual, self.weight.data, self.variance_epsilon)
        return y_supa, add_out_supa

    # NOTE(review): the standalone kernel appears to expect a 3-D input —
    # 2-D inputs are lifted with a leading batch dim, and 4-D inputs have
    # their leading dim squeezed (no-op unless that dim is size 1).
    # TODO: confirm against the supa_rmsnorm_infer kernel contract.
    if x.dim() == 2:
        x = x.unsqueeze(0)
    if x.dim() == 4:
        x = x.squeeze(0)
    return torch_br.supa_rmsnorm_infer(
        x, self.weight.data, self.variance_epsilon  # type: ignore
    )


# BUGFIX: the original used a plain @patch_to(RMSNorm), which installs
# ``enabled`` as an *instance* method — a class-level call such as
# ``RMSNorm.enabled()`` (vllm's CustomOp.enabled is a classmethod) would
# raise TypeError for the missing ``cls`` argument. ``cls_method=True``
# installs it as a proper classmethod, so both class- and instance-level
# calls work.
@patch_to(RMSNorm, cls_method=True)
def enabled(cls) -> bool:
    """Report that this out-of-tree RMSNorm implementation is enabled."""
    return True


@patch_to(nn.LayerNorm)
def forward(self, input: Tensor) -> Tensor:
    """LayerNorm forward that optionally uses Biren's fused kernel.

    The path is selected per call from the ``USE_BR_FUSED_LAYERNORM``
    environment variable: any value other than ``'false'``, ``'0'``, or
    the empty string (case-insensitive) enables the fused kernel.

    Args:
        input: Input tensor to normalize.

    Returns:
        The layer-normalized tensor.
    """
    flag = os.environ.get("USE_BR_FUSED_LAYERNORM", 'False').lower()
    if flag not in {'false', '0', ''}:
        return torch_br.fused_layernorm(input, self.weight, self.bias,
                                        self.eps)
    return nn.functional.layer_norm(input, self.normalized_shape,
                                    self.weight, self.bias, self.eps)