# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
|
|
|
from typing import Tuple
|
|
|
|
import torch
|
|
|
|
from vllm.model_executor.custom_op import CustomOp
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
|
|
|
|
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
|
from vllm_mlu import _mlu_ops as mlu_ops
|
|
from vllm_mlu.model_executor.models.layer_utils import is_per_token_smoothquant
|
|
|
|
|
|
@CustomOp.register("quant_fusion_rms_norm")
class QuantFusionRMSNorm(RMSNorm):
    """RMSNorm that folds the input quantization of a following projection
    into the fused MLU normalization kernel.

    The projection's quant method is told to skip quantizing its own input
    (this layer produces already-quantized activations for it).
    """

    def __init__(self, hidden_size: int, variance_epsilon: float, proj: LinearBase):
        super().__init__(hidden_size, variance_epsilon)
        assert not isinstance(
            proj.quant_method, UnquantizedLinearMethod
        ), f"UnquantizedLinearMethod of {proj.__class__.__name__} is not supported"
        # Quantization now happens inside the fused norm kernel, so the
        # projection must not quantize its input a second time.
        proj.quant_method.skip_quant_input = True
        is_dynamic = is_per_token_smoothquant(proj.quant_method.quant_config)
        # Per-token (dynamic) smoothquant uses the smooth factors; static
        # quantization uses the precomputed int conversion scale.
        scale_src = proj.smooth.data if is_dynamic else proj.scale_to_int.data
        self.dynamic_quant = is_dynamic
        self.quant_scale = torch.nn.Parameter(scale_src)

    def forward(
        self, x: torch.Tensor, residual: torch.Tensor | None = None
    ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None:
        """Run the fused RMSNorm + quantization kernel on ``x``.

        ``residual`` is forwarded to the kernel when supplied; the return
        value is whatever ``mlu_ops.fused_rms_norm`` produces.
        """
        weight = self.weight.data
        eps = self.variance_epsilon
        return mlu_ops.fused_rms_norm(
            x,
            residual,
            weight,
            None,
            None,
            eps,
            False,
            self.quant_scale.data,
            self.dynamic_quant,
        )
|
|
|
|
|
|
@CustomOp.register("quant_fusion_layer_norm")
class QuantFusionLayerNorm(torch.nn.LayerNorm, CustomOp):
    """LayerNorm that folds the input quantization of a following projection
    into the fused MLU normalization kernel.

    Mirrors ``QuantFusionRMSNorm`` but drives ``mlu_ops.fused_layer_norm``
    and forwards the (optional) LayerNorm bias.
    """

    def __init__(self, hidden_size: int, variance_epsilon: float, proj: LinearBase):
        super().__init__(hidden_size, variance_epsilon)
        assert not isinstance(
            proj.quant_method, UnquantizedLinearMethod
        ), f"UnquantizedLinearMethod of {proj.__class__.__name__} is not supported"
        # Quantization now happens inside the fused norm kernel, so the
        # projection must not quantize its input a second time.
        proj.quant_method.skip_quant_input = True
        is_dynamic = is_per_token_smoothquant(proj.quant_method.quant_config)
        # Per-token (dynamic) smoothquant uses the smooth factors; static
        # quantization uses the precomputed int conversion scale.
        scale_src = proj.smooth.data if is_dynamic else proj.scale_to_int.data
        self.dynamic_quant = is_dynamic
        self.quant_scale = torch.nn.Parameter(scale_src)

    def forward(
        self, x: torch.Tensor, residual: torch.Tensor | None = None
    ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None:
        """Run the fused LayerNorm + quantization kernel on ``x``."""
        bias_data = self.bias.data if self.bias is not None else None
        return mlu_ops.fused_layer_norm(
            x,
            residual,
            self.weight.data,
            bias_data,
            None,
            self.eps,
            False,
            self.quant_scale.data,
            self.dynamic_quant,
        )
|
|
|
|
|
|
def vllm__model_executor__layers__layernorm__RMSNorm__forward_oot(
    self,
    x: torch.Tensor,
    residual: torch.Tensor | None = None,
    out: torch.Tensor | None = None,
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None:
    """Out-of-tree RMSNorm forward that dispatches to the fused MLU kernel.

    Installed over ``RMSNorm.forward_oot`` via ``MluHijackObject.apply_hijack``
    at the bottom of this module.

    Args:
        x: Input activations; flattened to ``(-1, hidden_size)`` for the call.
        residual: Optional residual tensor, viewed to the same 2-D shape.
        out: Optional preallocated output buffer, viewed to the same 2-D
            shape and passed to the kernel as ``out=``.

    Returns:
        Without ``residual``: the normalized tensor restored to ``x``'s
        original shape. With ``residual``: a 2-tuple, both elements restored
        to the original shape. When ``out`` is given, the kernel result is
        returned as-is (no reshape).
    """
    org_shape = x.shape
    # Kernel operates on 2-D (tokens, hidden) views — presumably required by
    # mlu_ops.fused_rms_norm; TODO confirm against the kernel's contract.
    x = x.reshape(-1, self.weight.data.shape[0])
    if out is not None:
        out = out.view(-1, self.weight.data.shape[0])
    if residual is not None:
        residual = residual.view(-1, self.weight.data.shape[0])
        x = mlu_ops.fused_rms_norm(
            x,
            residual,
            self.weight.data,
            None,  # NOTE(review): semantics of these two positional Nones are
            None,  # defined by mlu_ops.fused_rms_norm and not visible here.
            self.variance_epsilon,
            True,  # presumably "residual supplied" — mirrors the False below
            out=out,
        )
    else:
        x = mlu_ops.fused_rms_norm(
            x,
            residual,
            self.weight.data,
            None,
            None,
            self.variance_epsilon,
            False,
            out=out,
        )

    if out is not None:
        # Caller owns the output buffer; hand back the kernel result directly.
        return x

    if residual is None:
        assert isinstance(x, torch.Tensor)
        return x.view(org_shape)

    # With a residual the kernel returns a (normalized, updated_residual) pair.
    assert isinstance(x, tuple)
    assert len(x) == 2
    return x[0].view(org_shape), x[1].view(org_shape)
|
|
|
|
# Module import side effect: swap vllm's stock RMSNorm.forward_oot for the
# MLU fused implementation defined above.
MluHijackObject.apply_hijack(
    RMSNorm,
    RMSNorm.forward_oot,
    vllm__model_executor__layers__layernorm__RMSNorm__forward_oot,
)
|