init
vllm/model_executor/layers/quantization/__init__.py (new file, 28 lines)
@@ -0,0 +1,28 @@
from typing import Type

from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.smoothquant import SmoothQuantConfig

_QUANTIZATION_CONFIG_REGISTRY = {
    "awq": AWQConfig,
    "gptq": GPTQConfig,
    "squeezellm": SqueezeLLMConfig,
    "marlin": MarlinConfig,
    "smoothquant": SmoothQuantConfig,
}


def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    if quantization not in _QUANTIZATION_CONFIG_REGISTRY:
        raise ValueError(f"Invalid quantization method: {quantization}")
    return _QUANTIZATION_CONFIG_REGISTRY[quantization]


__all__ = [
    "QuantizationConfig",
    "get_quantization_config",
]
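For reference, a minimal sketch of how this registry is consumed when a quantized checkpoint is loaded; the config dict below is a made-up example of what an AWQ quant_config.json might contain:

from vllm.model_executor.layers.quantization import get_quantization_config

# Hypothetical contents of a model's quant_config.json.
hf_quant_config = {"w_bit": 4, "q_group_size": 128, "zero_point": True}

config_cls = get_quantization_config("awq")             # -> AWQConfig
quant_config = config_cls.from_config(hf_quant_config)
linear_method = quant_config.get_linear_method()        # -> AWQLinearMethod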
vllm/model_executor/layers/quantization/awq.py (new file, 170 lines)
@@ -0,0 +1,170 @@
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig


class AWQConfig(QuantizationConfig):
    """Config class for AWQ.

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits.")
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (f"AWQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point})")

    def get_name(self) -> str:
        return "awq"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half]

    def get_min_capability(self) -> int:
        # The AWQ kernel only supports Turing or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
            "quantize_config.json",  # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        return cls(weight_bits, group_size, zero_point)

    def get_linear_method(self) -> "AWQLinearMethod":
        return AWQLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]


class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.

    Args:
        quant_config: The AWQ quantization config.
    """

    def __init__(self, quant_config: AWQConfig):
        self.quant_config = quant_config

    def create_weights(self, input_size_per_partition: int,
                       output_size_per_partition: int, input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, Any]:
        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        qweight = Parameter(
            torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
        qzeros = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
        scales = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "input_dim": 0,
            "output_dim": 1,
        })
        return {
            "qweight": qweight,
            "qzeros": qzeros,
            "scales": scales,
        }

    def apply_weights(self,
                      weights: Dict[str, Any],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = weights["qweight"]
        scales = weights["scales"]
        qzeros = weights["qzeros"]
        pack_factor = self.quant_config.pack_factor
        out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
        reshaped_x = x.reshape(-1, x.shape[-1])

        out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
                           pack_factor)
        # TODO align
        """
        # num_tokens >= threshold
        FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256

        if FP16_MATMUL_HEURISTIC_CONDITION:
            out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
            out = torch.matmul(reshaped_x, out)
        else:
            out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
                               pack_factor)
        """
        if bias is not None:
            out = out + bias
        return out.reshape(out_shape)
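To make the pack_factor bookkeeping above concrete, a standalone sketch of how eight 4-bit values share each int32 column of qweight (pure PyTorch; the nibble order is illustrative, the real AWQ kernel uses its own layout):

import torch

pack_factor = 32 // 4  # 8 int4 values per int32

def unpack_int4(qweight: torch.Tensor) -> torch.Tensor:
    # Unpack (rows, cols) int32 into (rows, cols * 8) values in [0, 15].
    rows, cols = qweight.shape
    out = torch.empty(rows, cols * pack_factor, dtype=torch.int32)
    for i in range(pack_factor):
        # Low 4 bits of the shifted word are nibble i; sign extension is masked off.
        out[:, i::pack_factor] = (qweight >> (4 * i)) & 0xF
    return out

packed = torch.randint(-2**31, 2**31 - 1, (2, 3), dtype=torch.int32)
print(unpack_int4(packed).shape)  # torch.Size([2, 24])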
vllm/model_executor/layers/quantization/base_config.py (new file, 64 lines)
@@ -0,0 +1,64 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List

import torch

from vllm.model_executor.layers.linear import LinearMethodBase


class QuantizationConfig(ABC):
    """Base class for quantization configs."""

    @abstractmethod
    def get_name(self) -> str:
        """Name of the quantization method."""
        raise NotImplementedError

    @abstractmethod
    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        """List of supported activation dtypes."""
        raise NotImplementedError

    @abstractmethod
    def get_min_capability(self) -> int:
        """Minimum GPU capability to support the quantization method.

        E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
        This requirement is due to the custom CUDA kernels used by the
        quantization method.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_config_filenames() -> List[str]:
        """List of filenames to search for in the model directory."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
        """Create a config class from the model's quantization config."""
        raise NotImplementedError

    @staticmethod
    def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
        """Get a value from the model's quantization config."""
        for key in keys:
            if key in config:
                return config[key]
        raise ValueError(f"Cannot find any of {keys} in the model's "
                         "quantization config.")

    @abstractmethod
    def get_linear_method(self) -> LinearMethodBase:
        """Get the linear method to use for the quantized linear layer."""
        raise NotImplementedError

    @abstractmethod
    def get_scaled_act_names(self) -> List[str]:
        """Returns the activation function names that should be post-scaled.

        For now, this is only used by AWQ.
        """
        raise NotImplementedError
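A minimal sketch of what a new backend would implement against this interface; DummyConfig is a placeholder, and UnquantizedLinearMethod is assumed to be the no-op linear method from vllm.model_executor.layers.linear:

from typing import Any, Dict, List

import torch

from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig


class DummyConfig(QuantizationConfig):
    """Placeholder config that quantizes nothing; it only illustrates the contract."""

    def get_name(self) -> str:
        return "dummy"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half]

    def get_min_capability(self) -> int:
        return 70  # Volta or newer.

    @staticmethod
    def get_config_filenames() -> List[str]:
        return ["quant_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "DummyConfig":
        return cls()

    def get_linear_method(self) -> LinearMethodBase:
        return UnquantizedLinearMethod()  # Assumed helper; swap in a real method.

    def get_scaled_act_names(self) -> List[str]:
        return []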
vllm/model_executor/layers/quantization/gptq.py (new file, 218 lines)
@@ -0,0 +1,218 @@
import enum
from enum import Enum
from typing import Any, Dict, List, Optional
from fractions import Fraction

import torch
from torch.nn.parameter import Parameter

from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


class GPTQConfig(QuantizationConfig):
    """Config class for GPTQ.

    Reference: https://arxiv.org/abs/2210.17323
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        desc_act: bool,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.pack_factor = Fraction(32, self.weight_bits)
        if self.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is supported for "
                f"GPTQ, but got {self.weight_bits} bits.")

    def __repr__(self) -> str:
        return (f"GPTQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act})")

    @classmethod
    def get_name(cls) -> str:
        return "gptq"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    # Need to figure it out
    def get_min_capability(cls) -> int:
        return 60

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        return cls(weight_bits, group_size, desc_act)

    def get_linear_method(self) -> "GPTQLinearMethod":
        return GPTQLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []


class ExllamaState(Enum):

    UNUSED = enum.auto()
    UNINITIALIZED = enum.auto()
    READY = enum.auto()


class GPTQLinearMethod(LinearMethodBase):
    """Linear method for GPTQ.

    Args:
        quant_config: The GPTQ quantization config.
    """

    def __init__(self, quant_config: GPTQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        input_size_per_partition: int,
        output_size_per_partition: int,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        del output_size  # Unused.
        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        if output_size_per_partition % self.quant_config.pack_factor.numerator != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size
        exllama_state = ExllamaState.UNINITIALIZED
        scale_and_zero_size = input_size // group_size
        scale_and_zero_input_dim = None
        if input_size != input_size_per_partition and self.quant_config.group_size != -1:
            # For act-order models, we cannot use Exllama for row parallel layer
            if self.quant_config.desc_act:
                raise NotImplementedError()
                exllama_state = ExllamaState.UNUSED
            else:
                # we need to partition qzeros and scales for exllama kernel
                scale_and_zero_size = input_size_per_partition // group_size
                scale_and_zero_input_dim = 0

        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 0,
                "pack_factor": self.quant_config.pack_factor,
            })
        g_idx = Parameter(
            torch.tensor(
                [
                    i // self.quant_config.group_size
                    for i in range(input_size_per_partition)
                ],
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        set_weight_attrs(g_idx, {"input_dim": 0, "ignore_warning": True})
        qzeros = Parameter(
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros, {
                "input_dim": scale_and_zero_input_dim,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
        scales = Parameter(
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "input_dim": scale_and_zero_input_dim,
            "output_dim": 1,
        })
        return {
            "qweight": qweight,
            "g_idx": g_idx,
            "qzeros": qzeros,
            "scales": scales,
            "exllama_state": exllama_state,
        }

    def apply_weights(self,
                      weights: Dict[str, Any],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = weights["qweight"]
        out_shape = x.shape[:-1] + (qweight.shape[-1], )
        reshaped_x = x.reshape(-1, x.shape[-1])
        # exllama needs to shuffle the weight after the weight is loaded
        # here we do the shuffle on first forward pass
        if weights["exllama_state"] == ExllamaState.UNINITIALIZED:
            if self.quant_config.desc_act:
                weights["g_idx"] = torch.argsort(weights["g_idx"]).to(
                    torch.int)
            else:
                weights["g_idx"] = None
                # TODO align
                """
                weights["g_idx"] = torch.empty((1, 1), device="meta")
                """
            weights["exllama_state"] = ExllamaState.READY
            ops.gptq_shuffle(weights["qweight"], weights["g_idx"],
                             self.quant_config.weight_bits)
        output = ops.gptq_gemm(reshaped_x, weights["qweight"],
                               weights["qzeros"], weights["scales"],
                               weights["g_idx"],
                               weights["exllama_state"] == ExllamaState.READY,
                               self.quant_config.weight_bits)
        if bias is not None:
            output = output + bias
        return output.reshape(out_shape)
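The Fraction-valued pack_factor is what lets the 3-bit case pass the divisibility checks above; a small worked example of the packed qweight row count (plain Python, example sizes only):

from fractions import Fraction

input_size_per_partition = 4096  # example shard size
for bits in (2, 3, 4, 8):
    pack_factor = Fraction(32, bits)
    # Row count of the packed qweight, as in GPTQLinearMethod.create_weights;
    # int // Fraction yields an exact integer here.
    rows = input_size_per_partition // pack_factor
    print(bits, pack_factor, rows)
# 2 16 256
# 3 32/3 384
# 4 8 512
# 8 4 1024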
vllm/model_executor/layers/quantization/marlin.py (new file, 210 lines)
@@ -0,0 +1,210 @@
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm._C import ops
from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig


class MarlinConfig(QuantizationConfig):
    """Config class for Marlin.

    Reference: https://github.com/IST-DASLab/marlin/tree/master
    """

    def __init__(
        self,
        group_size: int,
    ) -> None:
        # Group size for the quantization.
        self.group_size = group_size
        if self.group_size != 128 and self.group_size != -1:
            raise ValueError(
                "Currently, only group size 128 and -1 (channelwise) is supported for "
                f"Marlin, but got group_size of {self.group_size}")

        # 4 Bits packed into 32 bit datatype.
        self.pack_factor = 32 // 4

        # Tile size used by marlin kernels.
        self.tile_size = 16

        # Min out_features dim
        self.min_n_threads = 64

        # Min in_features dim
        self.min_k_threads = 128

        # Max parallel problems to solve at once (improves large batch performance)
        self.max_parallel = 16

        # Permutation length used by the marlin kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        return f"MarlinConfig(group_size={self.group_size})"

    @classmethod
    def get_name(cls) -> str:
        return "marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    # Need to figure it out
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(group_size)

    def get_linear_method(self) -> "MarlinLinearMethod":
        return MarlinLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []


class MarlinLinearMethod(LinearMethodBase):
    """Linear method for Marlin.

    Args:
        quant_config: The Marlin quantization config.
    """

    def __init__(self, quant_config: MarlinConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        input_size_per_partition: int,
        output_size_per_partition: int,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        del output_size  # Unused.

        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        # Validate output_size_per_partition
        if output_size_per_partition % self.quant_config.min_n_threads != 0:
            raise ValueError(
                f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by min_n_threads = {self.quant_config.min_n_threads}."
            )
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by pack_factor = {self.quant_config.pack_factor}."
            )

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_k_threads != 0:
            raise ValueError(
                f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by min_k_threads = {self.quant_config.min_k_threads}."
            )
        if self.quant_config.group_size != -1 and input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by group_size = {self.quant_config.group_size}."
            )

        # Check that we have at least 4 tiles horizontally in the shard
        num_tiles_per_perm = self.quant_config.perm_len // (
            self.quant_config.tile_size**2)
        if output_size_per_partition % num_tiles_per_perm != 0:
            raise ValueError(
                "Each permutation group must reside on the same gpu")

        # Quantized 4Bit weights packed into Int32.
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.tile_size,
                output_size_per_partition * self.quant_config.tile_size //
                self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
                "marlin_tile_size": self.quant_config.tile_size,
            },
        )

        # Determine if channelwise or not
        input_groups = 1 if self.quant_config.group_size == -1 else input_size_per_partition // self.quant_config.group_size

        scales = Parameter(
            torch.empty(
                input_groups,
                output_size_per_partition,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                "input_dim": None if input_groups == 1 else 0,
                "output_dim": 1,
            },
        )

        # Allocate workspace (Used for internal locking mechanism)
        max_workspace_size = (
            output_size_per_partition //
            self.quant_config.min_n_threads) * self.quant_config.max_parallel
        workspace = Parameter(torch.zeros(max_workspace_size,
                                          device="cuda",
                                          dtype=torch.int),
                              requires_grad=False)

        return {
            "B": qweight,
            "s": scales,
            "workspace": workspace,
        }

    def apply_weights(
        self,
        weights: Dict[str, Any],
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qweight = weights["B"]
        scales = weights["s"]
        workspace = weights["workspace"]

        x_2d = x.view(-1, x.shape[-1])

        size_m = x_2d.shape[0]
        size_k = x_2d.shape[1]
        size_n = scales.shape[1]

        output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
                                    size_n, size_k)

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output
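A back-of-the-envelope check of the tiled qweight shape produced above (plain arithmetic; 4096 x 11008 is just an example shard, not a required size):

tile_size = 16
pack_factor = 32 // 4  # 8

input_size_per_partition = 4096    # example K
output_size_per_partition = 11008  # example N

qweight_rows = input_size_per_partition // tile_size                 # 256
qweight_cols = output_size_per_partition * tile_size // pack_factor  # 22016
# Every int32 holds 8 packed 4-bit weights, so the element count matches K * N.
assert qweight_rows * qweight_cols * pack_factor == (
    input_size_per_partition * output_size_per_partition)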
vllm/model_executor/layers/quantization/smoothquant.py (new file, 111 lines)
@@ -0,0 +1,111 @@
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.parallel_utils.parallel_state import get_tensor_model_parallel_world_size


class SmoothQuantConfig(QuantizationConfig):
    """Config class for SmoothQuant.

    Reference: https://github.com/mit-han-lab/smoothquant
    """

    def __init__(
        self,
        weight_bits: int,
        quant_type: str = "tensor",
    ) -> None:
        self.weight_bits = weight_bits
        self.quant_type = quant_type

        if self.weight_bits != 8:
            raise ValueError(
                "Currently, only w8a8 quantization is supported for "
                f"SmoothQuant, but got {self.weight_bits} bits.")
        if self.quant_type != "tensor":
            raise ValueError(
                "Currently, only tensor-wise quantization is supported for "
                f"SmoothQuant, but got {self.quant_type} type quantization.")

    def __repr__(self) -> str:
        return (f"SmoothQuantConfig(weight_bits={self.weight_bits}, "
                f"quant_type={self.quant_type})")

    def get_name(self) -> str:
        return "smoothquant"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half, torch.float]

    def get_min_capability(self) -> int:
        return 70

    @staticmethod
    def get_config_filenames() -> List[str]:
        """List of filenames to search for in the model directory."""
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig":
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        quant_type = cls.get_from_keys(config, ["quant_type", "q_type"])
        return cls(weight_bits, quant_type)

    def get_linear_method(self) -> "SmoothLinearMethod":
        return SmoothLinearMethod(
            world_size=get_tensor_model_parallel_world_size())

    def get_scaled_act_names(self) -> List[str]:
        return []


class SmoothLinearMethod(LinearMethodBase):

    def __init__(self, world_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.apply_dequant_after_row = world_size > 1
        self.dtype = None

    def create_weights(
        self,
        input_size_per_partition: int,
        output_size_per_partition: int,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=torch.int8),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        self.dtype = params_dtype
        return {"weight": weight}

    def apply_weights(
        self,
        weights: Dict[str, torch.Tensor],
        x: torch.Tensor,
        bias: Optional[torch.Tensor],
        scale: Optional[torch.Tensor] = None,
        dequant_scale: float = 1.0,
        is_row: bool = False,
    ) -> torch.Tensor:
        x_shape = x.shape
        x = x.view(-1, x_shape[-1])
        weight = weights["weight"]
        y = torch.empty((x.shape[0], weight.shape[0]),
                        dtype=torch.int32,
                        device=x.device)
        ops.linear_a8_w8_o32_(x, weight, y)
        y = y.view(*x_shape[:-1], -1)
        if is_row and self.apply_dequant_after_row:
            # When tp > 1, dequantize first (to improve accuracy?).
            out = torch.empty_like(y, dtype=self.dtype)
            ops.dequant(out, y, scale, dequant_scale)
            y = out
        return y
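For context on what the int8 path expects, a sketch of symmetric per-tensor ("tensor" quant_type) quantization in plain PyTorch; the actual scale handling inside ops.linear_a8_w8_o32_ and ops.dequant lives in the custom CUDA ops and is assumed, not reproduced, here:

import torch

def quantize_per_tensor(x: torch.Tensor):
    # Symmetric per-tensor int8 quantization: x ≈ x_q * scale.
    scale = x.abs().max() / 127.0
    x_q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_q, scale

x = torch.randn(4, 8)
x_q, scale = quantize_per_tensor(x)
x_deq = x_q.float() * scale           # dequantized approximation of x
print((x - x_deq).abs().max())        # quantization error, typically small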
vllm/model_executor/layers/quantization/squeezellm.py (new file, 129 lines)
@@ -0,0 +1,129 @@
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.utils import is_hip


class SqueezeLLMConfig(QuantizationConfig):
    """Config class for SqueezeLLM.

    Reference: https://arxiv.org/pdf/2306.07629
    """

    def __init__(
        self,
        weight_bits: int,
    ) -> None:
        self.weight_bits = weight_bits

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"SqueezeLLM, but got {self.weight_bits} bits.")

        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return f"SqueezeLLMConfig(weight_bits={self.weight_bits})"

    def get_name(self) -> str:
        return "squeezellm"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half]

    def get_min_capability(self) -> int:
        return 70

    @staticmethod
    def get_config_filenames() -> List[str]:
        return ["quant_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig":
        weight_bits = cls.get_from_keys(config, ["wbits"])
        return cls(weight_bits)

    def get_linear_method(self) -> "SqueezeLLMLinearMethod":
        return SqueezeLLMLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []


class SqueezeLLMLinearMethod(LinearMethodBase):
    """Linear method for SqueezeLLM.

    Args:
        quant_config: The SqueezeLLM quantization config.
    """

    def __init__(self, quant_config: SqueezeLLMConfig):
        self.quant_config = quant_config

    def create_weights(self, input_size_per_partition: int,
                       output_size_per_partition: int, input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, Any]:
        if input_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 0,
                "pack_factor": self.quant_config.pack_factor,
            })
        lookup_table = Parameter(
            torch.empty(
                output_size,
                self.quant_config.weight_bits**2,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(lookup_table, {
            "output_dim": 0,
        })
        return {
            "qweight": qweight,
            "lookup_table": lookup_table,
        }

    def apply_weights(self,
                      weights: Dict[str, Any],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = weights["qweight"]
        lookup_table = weights["lookup_table"]
        out_shape = x.shape[:-1] + (qweight.shape[-1], )
        reshaped_x = x.reshape(-1, x.shape[-1])
        if is_hip():
            out_f = torch.zeros(out_shape, dtype=torch.float)
            ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
            out = out_f.to(dtype=torch.float16)
        else:
            # NOTE: The output tensor should be zero-initialized.
            out = torch.zeros(out_shape, dtype=torch.float16)
            ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)

        if bias is not None:
            out = out + bias
        return out.reshape(out_shape)
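To illustrate what lookup_table encodes, a reference dequantization sketch in plain PyTorch (the nibble order is illustrative; the real ops.squeezellm_gemm kernel fuses this lookup with the matmul):

import torch

pack_factor = 32 // 4  # 8 4-bit indices per int32

def dequantize(qweight: torch.Tensor, lookup_table: torch.Tensor) -> torch.Tensor:
    # qweight: (in_dim // 8, out_dim) int32 of packed 4-bit centroid indices.
    # lookup_table: (out_dim, 16) per-output-channel centroid values.
    in_packed, out_dim = qweight.shape
    w = torch.empty(in_packed * pack_factor, out_dim, dtype=lookup_table.dtype)
    for i in range(pack_factor):
        idx = (qweight >> (4 * i)) & 0xF  # (in_packed, out_dim) indices in [0, 15]
        # Gather each output channel's centroid from its own row of the table.
        w[i::pack_factor] = torch.gather(lookup_table, 1, idx.t().long()).t()
    return w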