[Model] Support DeepSeek-V4
This commit is contained in:
130
vllm_mlu/model_executor/layers/layernorm.py
Normal file
130
vllm_mlu/model_executor/layers/layernorm.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.model_executor.models.layer_utils import is_per_token_smoothquant
|
||||
|
||||
|
||||
@CustomOp.register("quant_fusion_rms_norm")
class QuantFusionRMSNorm(RMSNorm):
    """RMSNorm fused with output quantization for a downstream projection.

    Wraps vLLM's ``RMSNorm`` and routes the forward pass through the MLU
    ``fused_rms_norm`` kernel, passing along the projection's quantization
    scale so normalization and input quantization happen in one kernel call.
    The wrapped projection is told to skip quantizing its own input, since
    this layer already emits quantized activations.
    """

    def __init__(self, hidden_size: int, variance_epsilon: float, proj: LinearBase):
        super().__init__(hidden_size, variance_epsilon)
        # Fusion only makes sense for a quantized projection: an unquantized
        # one has no scale to fold into the norm kernel.
        assert not isinstance(
            proj.quant_method, UnquantizedLinearMethod
        ), f"UnquantizedLinearMethod of {proj.__class__.__name__} is not supported"
        # This layer quantizes the activation itself, so the projection must
        # not quantize its input a second time.
        proj.quant_method.skip_quant_input = True
        dynamic = is_per_token_smoothquant(proj.quant_method.quant_config)
        # Per-token smoothquant uses the smoothing vector; static quant uses
        # the precomputed int scale.
        scale_source = proj.smooth.data if dynamic else proj.scale_to_int.data
        self.dynamic_quant = dynamic
        self.quant_scale = torch.nn.Parameter(scale_source)

    def forward(
        self, x: torch.Tensor, residual: torch.Tensor | None = None
    ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None:
        """Run fused RMSNorm (+ optional residual add) with quantized output."""
        weight = self.weight.data
        scale = self.quant_scale.data
        return mlu_ops.fused_rms_norm(
            x,
            residual,
            weight,
            None,
            None,
            self.variance_epsilon,
            False,
            scale,
            self.dynamic_quant,
        )
|
||||
|
||||
|
||||
@CustomOp.register("quant_fusion_layer_norm")
class QuantFusionLayerNorm(torch.nn.LayerNorm, CustomOp):
    """LayerNorm fused with output quantization for a downstream projection.

    Mirrors ``QuantFusionRMSNorm`` but builds on ``torch.nn.LayerNorm``:
    the forward pass calls the MLU ``fused_layer_norm`` kernel with the
    projection's quantization scale so the normalized activation comes out
    already quantized. The wrapped projection is told to skip quantizing its
    own input accordingly.
    """

    def __init__(self, hidden_size: int, variance_epsilon: float, proj: LinearBase):
        super().__init__(hidden_size, variance_epsilon)
        # Fusion requires a quantized projection; there is no scale to use
        # otherwise.
        assert not isinstance(
            proj.quant_method, UnquantizedLinearMethod
        ), f"UnquantizedLinearMethod of {proj.__class__.__name__} is not supported"
        # The activation is quantized here, so the projection must not
        # re-quantize its input.
        proj.quant_method.skip_quant_input = True
        dynamic = is_per_token_smoothquant(proj.quant_method.quant_config)
        # Per-token smoothquant -> smoothing vector; static quant -> int scale.
        scale_source = proj.smooth.data if dynamic else proj.scale_to_int.data
        self.dynamic_quant = dynamic
        self.quant_scale = torch.nn.Parameter(scale_source)

    def forward(
        self, x: torch.Tensor, residual: torch.Tensor | None = None
    ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None:
        """Run fused LayerNorm (+ optional residual add) with quantized output."""
        bias = self.bias.data if self.bias is not None else None
        return mlu_ops.fused_layer_norm(
            x,
            residual,
            self.weight.data,
            bias,
            None,
            self.eps,
            False,
            self.quant_scale.data,
            self.dynamic_quant,
        )
|
||||
|
||||
|
||||
def vllm__model_executor__layers__layernorm__RMSNorm__forward_oot(
    self,
    x: torch.Tensor,
    residual: torch.Tensor | None = None,
    out: torch.Tensor | None = None,
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor] | None:
    """Out-of-tree replacement for vLLM's ``RMSNorm.forward`` on MLU.

    Flattens the input to 2-D (tokens x hidden), dispatches to the MLU
    ``fused_rms_norm`` kernel, and restores the original shape on the way
    out.

    Args:
        x: Input activations; any leading shape, last dim = hidden size.
        residual: Optional residual tensor fused into the norm; when given,
            the kernel returns a ``(normed, residual)`` pair.
        out: Optional preallocated output buffer; when given, the kernel
            result is returned as-is without reshaping.

    Returns:
        The normalized tensor (reshaped to ``x``'s original shape), a
        ``(normed, updated_residual)`` tuple when ``residual`` is given, or
        the kernel's raw result when ``out`` is supplied.
    """
    org_shape = x.shape
    # Hoist the hidden size instead of re-reading self.weight.data.shape[0]
    # for every view below.
    hidden_size = self.weight.data.shape[0]
    x = x.reshape(-1, hidden_size)
    if out is not None:
        out = out.view(-1, hidden_size)
    if residual is not None:
        residual = residual.view(-1, hidden_size)
    # Single kernel call replaces the previous duplicated if/else: the two
    # branches differed only in the boolean flag, which tracks whether a
    # residual is fused (residual is None -> False, else True).
    x = mlu_ops.fused_rms_norm(
        x,
        residual,
        self.weight.data,
        None,
        None,
        self.variance_epsilon,
        residual is not None,
        out=out,
    )

    # Caller-provided buffer: hand back the kernel result untouched.
    if out is not None:
        return x

    if residual is None:
        assert isinstance(x, torch.Tensor)
        return x.view(org_shape)

    # With a fused residual the kernel returns (normed, updated_residual).
    assert isinstance(x, tuple)
    assert len(x) == 2
    return x[0].view(org_shape), x[1].view(org_shape)
|
||||
|
||||
# Monkey-patch vLLM's RMSNorm at import time: register the MLU implementation
# above in place of RMSNorm.forward_oot, so norm calls are routed to the fused
# MLU kernel. (apply_hijack presumably swaps the bound method on the class —
# confirm against MluHijackObject.)
MluHijackObject.apply_hijack(
    RMSNorm,
    RMSNorm.forward_oot,
    vllm__model_executor__layers__layernorm__RMSNorm__forward_oot,
)
|
||||
Reference in New Issue
Block a user