[Model] Support DeepSeek-V4
This commit is contained in:
440
vllm_mlu/model_executor/layers/quantization/gptq_mlu.py
Normal file
440
vllm_mlu/model_executor/layers/quantization/gptq_mlu.py
Normal file
@@ -0,0 +1,440 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from fractions import Fraction
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import register_quantization_config
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter)
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MLU_SUPPORTED_GROUP_SIZES = [64, 128, 256, 512]
|
||||
|
||||
# We only support gptq and awq over 300 serials and only support int4 and int8 precision
|
||||
def query_mlu_supported_quant_types(has_zp: bool,
|
||||
device_capability: Optional[int] = None
|
||||
):
|
||||
if device_capability is None:
|
||||
major, minor = current_platform.get_device_capability()
|
||||
device_capability = major * 10 + minor
|
||||
|
||||
if has_zp:
|
||||
# AWQ style, unsigned + zero-point
|
||||
return [scalar_types.uint4, scalar_types.uint8]
|
||||
else:
|
||||
# GPTQ style, unsigned + symmetric bias
|
||||
return [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
|
||||
|
||||
def check_mlu_supported(
|
||||
quant_type: ScalarType,
|
||||
group_size: Optional[int],
|
||||
has_zp: bool,
|
||||
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
|
||||
|
||||
if device_capability is None:
|
||||
major, minor = current_platform.get_device_capability()
|
||||
device_capability = major * 10 + minor
|
||||
|
||||
supported_types = query_mlu_supported_quant_types(
|
||||
has_zp, device_capability)
|
||||
|
||||
if quant_type not in supported_types:
|
||||
return (False, f"Mlu does not support weight_bits = {quant_type}. "
|
||||
f"Only types = {supported_types} "
|
||||
f"are supported (for group_size = {group_size}, "
|
||||
f"device_capability = {device_capability}, zp = {has_zp}).")
|
||||
if (group_size is None or group_size not in MLU_SUPPORTED_GROUP_SIZES):
|
||||
return (False, f"Mlu does not support group_size = {group_size}. "
|
||||
f"Only group_sizes = {MLU_SUPPORTED_GROUP_SIZES} "
|
||||
"are supported.")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# @register_quantization_config("gptq_mlu")
|
||||
class GPTQMluConfig(QuantizationConfig):
|
||||
"""Config class for GPTQMlu.
|
||||
|
||||
Reference: https://arxiv.org/abs/2210.17323
|
||||
"""
|
||||
|
||||
# (num_bits, is_sym) -> quant_type
|
||||
TYPE_MAP = {
|
||||
(4, True): scalar_types.uint4b8,
|
||||
(8, True): scalar_types.uint8b128,
|
||||
(4, False): scalar_types.uint4b8,
|
||||
(8, False): scalar_types.uint8b128,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
desc_act: bool,
|
||||
is_sym: bool,
|
||||
lm_head_quantized: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.desc_act = desc_act
|
||||
self.is_sym = is_sym
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.pack_factor = Fraction(32, self.weight_bits)
|
||||
self.support_scale_zeros = False
|
||||
self.use_native = self.desc_act or (not self.is_sym and not self.support_scale_zeros)
|
||||
|
||||
if self.weight_bits not in [4, 8]:
|
||||
raise ValueError(
|
||||
"Currently, only 4/8-bit weight quantization is "
|
||||
f"supported for GPTQMlu, but got {self.weight_bits} bits.")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"GPTQMluConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act}),"
|
||||
f"lm_head_quantized={self.lm_head_quantized}")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "gptq_mlu"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16, torch.float32]
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return ["quant_config.json", "quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "GPTQMluConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
is_sym = cls.get_from_keys(config, ["sym"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
return cls(weight_bits, group_size, desc_act, is_sym, lm_head_quantized)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["GPTQMluLinearMethod"]:
|
||||
if (isinstance(layer, LinearBase) or
|
||||
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
|
||||
return GPTQMluLinearMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
|
||||
|
||||
@classmethod
|
||||
def is_gptq_mlu_compatible(cls, quant_config: Dict[str, Any]):
|
||||
# Extract data from quant config.
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits", None)
|
||||
group_size = quant_config.get("group_size", None)
|
||||
sym = quant_config.get("sym", None)
|
||||
desc_act = quant_config.get("desc_act", None)
|
||||
|
||||
if quant_method != "gptq":
|
||||
return False
|
||||
|
||||
# If we cannot find the info needed in the config, cannot convert.
|
||||
if (num_bits is None or group_size is None or sym is None
|
||||
or desc_act is None):
|
||||
return False
|
||||
|
||||
if (num_bits, sym) not in cls.TYPE_MAP:
|
||||
return False
|
||||
|
||||
return check_mlu_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
|
||||
group_size=group_size, has_zp=False)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
can_convert = cls.is_gptq_mlu_compatible(hf_quant_cfg)
|
||||
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
|
||||
or user_quant == "gptq_mlu")
|
||||
|
||||
if can_convert and is_valid_user_quant:
|
||||
msg = ("The model is convertible to {} during runtime."
|
||||
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
return None
|
||||
|
||||
class GPTQMluLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GPTQMlu.
|
||||
|
||||
Args:
|
||||
quant_config: The GPTQMlu quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: GPTQMluConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del output_size # Unused.
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
if input_size_per_partition % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if (output_size_per_partition % self.quant_config.pack_factor.numerator
|
||||
!= 0):
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
|
||||
scale_and_zero_size = input_size // group_size
|
||||
scale_and_zero_input_dim = None
|
||||
if (input_size != input_size_per_partition) and (self.quant_config.group_size !=
|
||||
-1) and (not self.quant_config.desc_act):
|
||||
scale_and_zero_size = input_size_per_partition // group_size
|
||||
scale_and_zero_input_dim = 0
|
||||
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
output_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=0,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
g_idx = RowvLLMParameter(data=torch.tensor(
|
||||
[
|
||||
i // self.quant_config.group_size
|
||||
for i in range(input_size_per_partition)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
qzeros_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
weight_scale_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
if scale_and_zero_input_dim is None:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedColumnParameter(
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(output_dim=1,
|
||||
input_dim=0,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedvLLMParameter(
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("g_idx", g_idx)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.device = layer.qweight.data.device
|
||||
packed_qweight, scale_zeros = self.extract_autogptq(layer)
|
||||
if self.quant_config.use_native:
|
||||
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
|
||||
layer.qzeros = None
|
||||
layer.scales = None
|
||||
else:
|
||||
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
|
||||
if scale_zeros is not None:
|
||||
layer.qzeros = torch.nn.Parameter(scale_zeros.contiguous(), requires_grad=False)
|
||||
else:
|
||||
layer.qzeros = None
|
||||
layer.scales = torch.nn.Parameter(layer.scales.transpose(0, 1).contiguous(), requires_grad=False)
|
||||
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if self.quant_config.use_native:
|
||||
output = mlu_ops.matmul(x, layer.qweight, bias)
|
||||
if residual is not None:
|
||||
output = output + residual
|
||||
else:
|
||||
output = mlu_ops.weight_only_quant_matmul(x,
|
||||
layer.qweight,
|
||||
layer.scales,
|
||||
layer.qzeros,
|
||||
bias,
|
||||
residual,
|
||||
"none",
|
||||
self.quant_config.weight_bits)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def extract_autogptq(self, layer: torch.nn.Module):
|
||||
scales = layer.scales.data
|
||||
bits = self.quant_config.weight_bits
|
||||
group_size = self.quant_config.group_size
|
||||
# Unpack the qweight and qzeros tensors
|
||||
iweight = self.unpack_gptq_qweight_int32_into_int8(layer.qweight.data, bits)
|
||||
izeros = self.unpack_gptq_qzeros_int32_into_int8(layer.qzeros.data, bits)
|
||||
|
||||
if self.quant_config.use_native:
|
||||
if self.quant_config.desc_act:
|
||||
scales = torch.index_select(scales, 0, layer.g_idx)
|
||||
if izeros is not None:
|
||||
izeros = torch.index_select(izeros, 0, layer.g_idx)
|
||||
else:
|
||||
scales = scales.repeat_interleave(group_size, dim=0)
|
||||
if izeros is not None:
|
||||
izeros = izeros.repeat_interleave(group_size, dim=0)
|
||||
|
||||
if izeros is not None:
|
||||
fweight = (iweight - izeros) * scales
|
||||
else:
|
||||
fweight = iweight * scales
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
fweight = fweight.transpose(0, 1)
|
||||
|
||||
return fweight, None
|
||||
|
||||
if not self.quant_config.is_sym and self.quant_config.support_scale_zeros and izeros is not None:
|
||||
scale_zeros = izeros.to(scales.dtype) * -1 * scales
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
scale_zeros = scale_zeros.transpose(0, 1)
|
||||
else:
|
||||
# for is_sym is true now, so make iweight to sign value and ignore qzeros
|
||||
iweight = torch.bitwise_and(iweight - 2**(bits - 1), (2 ** bits) - 1)
|
||||
scale_zeros = None
|
||||
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
iweight = iweight.to(torch.int8).transpose(0, 1)
|
||||
|
||||
if bits == 4:
|
||||
higher_bit_tensor = iweight[:, 1::2]
|
||||
lower_bit_tensor = iweight[:, 0::2]
|
||||
packed_qweight = self.combine_low_bits(higher_bit_tensor, lower_bit_tensor)
|
||||
else:
|
||||
packed_qweight = iweight
|
||||
|
||||
return packed_qweight, scale_zeros
|
||||
|
||||
|
||||
def unpack_gptq_qweight_int32_into_int8(self, qweight: torch.Tensor, bits: int):
|
||||
shifts = torch.arange(0, 32, bits, device=qweight.device).unsqueeze(0)
|
||||
dtype = torch.int16 if bits == 8 else torch.int8
|
||||
weight = torch.bitwise_right_shift(
|
||||
torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1),
|
||||
shifts.unsqueeze(-1),
|
||||
).to(dtype)
|
||||
weight = torch.bitwise_and(weight, (2**bits) - 1)
|
||||
weight = weight.reshape(-1, weight.shape[-1])
|
||||
|
||||
return weight
|
||||
|
||||
|
||||
def unpack_gptq_qzeros_int32_into_int8(self, qzeros: torch.Tensor, bits: int):
|
||||
shifts = torch.arange(0, 32, bits, device=qzeros.device).unsqueeze(0)
|
||||
dtype = torch.int16 if bits == 8 else torch.int8
|
||||
zeros = torch.bitwise_right_shift(
|
||||
torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits),
|
||||
shifts.unsqueeze(0),
|
||||
).to(dtype)
|
||||
|
||||
zeros = zeros + 1
|
||||
|
||||
zeros = torch.bitwise_and(zeros, (2**bits) - 1)
|
||||
zeros = zeros.reshape(qzeros.shape[0], -1)
|
||||
|
||||
return zeros
|
||||
|
||||
|
||||
def combine_low_bits(self, tensor_a, tensor_b):
|
||||
"""
|
||||
Combine the lower 4 bits of two int8 tensors into a new int8 tensor.
|
||||
|
||||
Args:
|
||||
tensor_a (torch.Tensor): First tensor of type int8.
|
||||
tensor_b (torch.Tensor): Second tensor of type int8.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: New tensor of type int8, combining lower 4 bits of tensor_a and tensor_b.
|
||||
"""
|
||||
# 确保输入是 int8 类型
|
||||
if tensor_a.dtype != torch.int8 or tensor_b.dtype != torch.int8:
|
||||
raise ValueError("Both tensors must be of int8 type.")
|
||||
|
||||
# 提取每个 tensor 的低4位
|
||||
low_bits_a = torch.bitwise_and(tensor_a, 0x0F) # 保留 tensor_a 的低4位
|
||||
low_bits_b = torch.bitwise_and(tensor_b, 0x0F) # 保留 tensor_b 的低4位
|
||||
|
||||
# 将 tensor_a 的低4位左移4位
|
||||
shifted_low_bits_a = low_bits_a << 4
|
||||
|
||||
# 组合两个 tensor 的低4位
|
||||
combined = torch.bitwise_or(shifted_low_bits_a, low_bits_b)
|
||||
|
||||
return combined
|
||||
Reference in New Issue
Block a user