first commit
This commit is contained in:
67
vllm_br/model_executor/layers/layernorm.py
Normal file
67
vllm_br/model_executor/layers/layernorm.py
Normal file
@@ -0,0 +1,67 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
import os
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from fastcore.basics import patch_to
|
||||
from torch import Tensor, nn
|
||||
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
|
||||
|
||||
@patch_to(RMSNorm)
|
||||
def forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
if self.weight.data.dtype == torch.bfloat16:
|
||||
self.weight.data = self.weight.data.to(torch.float32)
|
||||
|
||||
if residual is not None:
|
||||
y_supa, add_out_supa = torch_br.supa_add_rmsnorm_infer( # type: ignore
|
||||
x, residual, self.weight.data, self.variance_epsilon)
|
||||
return y_supa, add_out_supa
|
||||
else:
|
||||
if len(x.shape) == 2:
|
||||
x = x.unsqueeze(0)
|
||||
if len(x.shape) == 4:
|
||||
x = x.squeeze(0)
|
||||
|
||||
x = torch_br.supa_rmsnorm_infer(
|
||||
x,
|
||||
self.weight.data,
|
||||
self.variance_epsilon # type: ignore
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
@patch_to(RMSNorm)
|
||||
def enabled(cls) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@patch_to(nn.LayerNorm)
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
if os.environ.get("USE_BR_FUSED_LAYERNORM",
|
||||
'False').lower() not in {'false', '0', ''}:
|
||||
return torch_br.fused_layernorm(input, self.weight, self.bias,
|
||||
self.eps)
|
||||
else:
|
||||
return nn.functional.layer_norm(input, self.normalized_shape,
|
||||
self.weight, self.bias, self.eps)
|
||||
Reference in New Issue
Block a user