242 lines
9.2 KiB
Python
242 lines
9.2 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from typing import Any, Optional
|
|
|
|
import torch
|
|
|
|
from vllm import _custom_ops as ops
|
|
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
|
UnquantizedLinearMethod)
|
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
|
from vllm.model_executor.layers.quantization.base_config import (
|
|
QuantizationConfig)
|
|
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
|
PackedvLLMParameter)
|
|
from vllm.utils import direct_register_custom_op
|
|
|
|
class AWQConfig(QuantizationConfig):
    """Quantization config for AWQ (Activation-aware Weight Quantization).

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        modules_to_not_convert: Optional[list[str]] = None,
    ) -> None:
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        # A missing or empty list both normalize to a fresh empty list.
        if modules_to_not_convert:
            self.modules_to_not_convert = modules_to_not_convert
        else:
            self.modules_to_not_convert = []

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits.")
        # Number of quantized values packed into one 32-bit word.
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        fields = (
            f"weight_bits={self.weight_bits}",
            f"group_size={self.group_size}",
            f"zero_point={self.zero_point}",
            f"modules_to_not_convert={self.modules_to_not_convert}",
        )
        return "AWQConfig(" + ", ".join(fields) + ")"

    def get_name(self) -> QuantizationMethods:
        return "awq"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        # The AWQ kernel only supports Turing (compute capability 7.5) or
        # newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> list[str]:
        # Example repos:
        #   quant_config.json    -> casperhansen/vicuna-7b-v1.5-awq
        #   quantize_config.json -> abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
        return ["quant_config.json", "quantize_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "AWQConfig":
        """Build an AWQConfig from a checkpoint's quantization config dict.

        Each setting accepts the alternative key spellings listed below.
        ``modules_to_not_convert`` is optional and defaults to None.
        """
        bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group = cls.get_from_keys(config, ["q_group_size", "group_size"])
        has_zero_point = cls.get_from_keys(config, ["zero_point"])
        skip_modules = cls.get_from_keys_or(config,
                                            ["modules_to_not_convert"], None)
        return cls(bits, group, has_zero_point, skip_modules)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["LinearMethodBase"]:
        """Pick the linear method for ``layer``.

        Returns None for layers that are not LinearBase. Returns the
        unquantized method for layers listed in ``modules_to_not_convert``,
        and AWQ for every other linear layer.
        """
        if not isinstance(layer, LinearBase):
            return None
        if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
            return UnquantizedLinearMethod()
        return AWQLinearMethod(self)
|
|
|
|
|
|
def is_layer_skipped_awq(prefix: str,
                         modules_to_not_convert: list[str]) -> bool:
    """Return True if any name in ``modules_to_not_convert`` occurs in ``prefix``.

    Matching is plain substring containment. For example, "layers.1" also
    matches "model.layers.10.mlp".
    """
    for skipped in modules_to_not_convert:
        if skipped in prefix:
            return True
    return False
|
|
|
|
def _apply_awq_fake(x: torch.Tensor,
|
|
qweight: torch.Tensor,
|
|
scales: torch.Tensor,
|
|
qzeros: torch.Tensor,
|
|
bias: torch.Tensor,
|
|
pack_factor: int,
|
|
group_size: int) -> torch.Tensor:
|
|
out_shape = ()
|
|
if group_size % 32:
|
|
out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
|
|
else:
|
|
out_shape = (x.shape[:-1] + (qweight.shape[0], ))
|
|
return torch.empty(out_shape, dtype=x.dtype, device=x.device)
|
|
|
|
def _apply_awq(x: torch.Tensor,
               qweight: torch.Tensor,
               scales: torch.Tensor,
               qzeros: torch.Tensor,
               bias: Optional[torch.Tensor],
               pack_factor: int,
               group_size: int) -> torch.Tensor:
    """Compute the AWQ-quantized linear layer output for ``x``.

    ``x`` is flattened to 2-D (all leading dims merged). One of two paths is
    taken:

    * ``group_size % 32 != 0`` (including ``group_size == -1``): the weight
      is fully dequantized with ``ops.awq_dequantize`` and a dense
      ``torch.matmul`` is used. The output width is
      ``qweight.shape[-1] * pack_factor``.
    * otherwise: ``ops.awq_gemm`` is called on the repacked ``qweight``. The
      output width is ``qweight.shape[0]``.

    ``bias``, if given, is added in place. The result is reshaped back to
    ``x.shape[:-1] + (out_features,)``.

    ``bias`` is Optional because ``AWQLinearMethod.apply`` forwards ``None``
    for bias-free layers. The annotation lets the inferred custom-op schema
    declare it as ``Tensor?``.
    """
    reshaped_x = x.reshape(-1, x.shape[-1])

    if group_size % 32:
        out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor, )
        weight = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
        out = torch.matmul(reshaped_x, weight)
    else:
        num_out_channel = qweight.shape[0]
        out_shape = x.shape[:-1] + (num_out_channel, )
        use_bf16 = reshaped_x.dtype == torch.bfloat16
        if use_bf16:
            # The bf16 kernel gets a zero-initialized float32
            # (num_tokens, num_out_channel) workspace. It is presumably an
            # accumulation buffer; confirm against the awq_gemm kernel.
            temp_space = torch.zeros(reshaped_x.shape[0],
                                     num_out_channel,
                                     dtype=torch.float32,
                                     device=x.device)
        else:
            temp_space = torch.empty(0, dtype=torch.float32, device=x.device)
        out = ops.awq_gemm(reshaped_x, qweight, qzeros, scales, pack_factor,
                           temp_space, use_bf16)

    if bias is not None:
        out.add_(bias)
    return out.reshape(out_shape)
|
|
|
|
# Expose _apply_awq through the torch op registry (AWQLinearMethod.apply
# calls it as torch.ops.vllm._apply_awq). _apply_awq_fake supplies the
# output shape, dtype and device without running the quantized kernels.
direct_register_custom_op(
    op_name="_apply_awq",
    op_func=_apply_awq,
    fake_impl=_apply_awq_fake,
    mutates_args=[],
    tags=(torch.Tag.needs_fixed_stride_order, ),
)
|
|
|
|
class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.

    Args:
        quant_config: The AWQ quantization config.
    """

    def __init__(self, quant_config: AWQConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register packed ``qweight``/``qzeros`` and ``scales`` on ``layer``.

        Shapes created, with ``P = quant_config.pack_factor`` and
        ``G = input_size_per_partition // group_size``:

        * ``qweight``: int32 ``(input_size_per_partition, out // P)``
        * ``qzeros``:  int32 ``(G, out // P)``
        * ``scales``:  ``params_dtype`` ``(G, out)``

        Here ``out = sum(output_partition_sizes)``. A configured group size
        of -1 means a single group spanning ``input_size``.

        Raises:
            ValueError: if the input partition is not divisible by the group
                size, or the output partition is not divisible by the pack
                factor.
        """
        pack_factor = self.quant_config.pack_factor
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        if input_size_per_partition % group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        weight_loader = extra_weight_attrs.get("weight_loader")
        num_groups = input_size_per_partition // group_size
        packed_out = output_size_per_partition // pack_factor

        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                packed_out,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=pack_factor,
            weight_loader=weight_loader)

        qzeros = PackedvLLMParameter(
            data=torch.empty(
                num_groups,
                packed_out,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=pack_factor,
            weight_loader=weight_loader)

        scales = GroupQuantScaleParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=0,
            output_dim=1,
            weight_loader=weight_loader)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Freeze the loaded tensors and repack ``qweight`` when needed."""
        # Replace the loader-aware parameter wrappers with plain, frozen
        # Parameters over the same data.
        layer.qweight = torch.nn.Parameter(layer.qweight.data,
                                           requires_grad=False)
        layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
                                          requires_grad=False)
        layer.scales = torch.nn.Parameter(layer.scales.data,
                                          requires_grad=False)

        # Repack with awq_to_gptq_4bit under exactly the condition in which
        # _apply_awq dispatches to awq_gemm (group_size a multiple of 32).
        # Other group sizes, including -1, keep the original AWQ packing for
        # the awq_dequantize path.
        if self.quant_config.group_size % 32 == 0:
            qweight = ops.awq_to_gptq_4bit(layer.qweight)
            layer.qweight = torch.nn.Parameter(qweight, requires_grad=False)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the AWQ linear layer through the registered ``_apply_awq`` op.

        The raw configured ``group_size`` (possibly -1) is forwarded because
        it selects the kernel path inside the op.
        """
        return torch.ops.vllm._apply_awq(x, layer.qweight, layer.scales,
                                         layer.qzeros, bias,
                                         self.quant_config.pack_factor,
                                         self.quant_config.group_size)
|