Sync from v0.13
320
vllm/model_executor/layers/quantization/gptq_marlin_24.py
Normal file
@@ -0,0 +1,320 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, Optional

import torch
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import (
    QuantizationConfig,
    QuantizationMethods,
)
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    ChannelQuantScaleParameter,
    GroupQuantScaleParameter,
    PackedvLLMParameter,
)
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)

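# The constants below are assumed to mirror the tiling and launch
# configuration of the Marlin 2:4 sparse GEMM kernel; create_weights()
# validates layer shapes against them before allocating storage.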
GPTQ_MARLIN_24_TILE = 16
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64

GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]


class GPTQMarlin24Config(QuantizationConfig):
    """Config class for Marlin24."""

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
    ) -> None:
        super().__init__()
        quant_type = {
            4: scalar_types.uint4b8,
            8: scalar_types.uint8b128,
        }.get(weight_bits)

        self.group_size = group_size

        # Verify
        if quant_type is None or quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES:
            raise ValueError(
                f"Marlin_24 does not support quant_type = {quant_type}. "
                f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} "
                "are supported."
            )
        if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"Marlin_24 does not support group_size = {self.group_size}. "
                f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
                "are supported."
            )

        self.quant_type = quant_type

        # Quantized weight values packed into a 32-bit datatype.
        self.pack_factor = 32 // self.quant_type.size_bits

        # Tile size used by marlin kernels.
        self.tile_size = 16

        # Min out_features dim
        self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N

        # Min in_features dim
        self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K

        # Max parallel problems to solve at once (improves large
        # batch performance)
        self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL

        # Permutation length used by the marlin kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        return "Marlin24Config(quant_type={}, group_size={})".format(
            self.quant_type, self.group_size
        )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "gptq_marlin_24"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.half]

    @classmethod
    # Need to figure it out
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "GPTQMarlin24Config":
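        # Illustrative quantize_config.json (assumed, following the usual
        # GPTQ field names; only "bits" and "group_size" are consumed here):
        #   {"bits": 4, "group_size": 128, "checkpoint_format": "marlin_24"}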
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits, group_size)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        is_marlin_24_format = hf_quant_cfg.get("checkpoint_format") == "marlin_24"

        is_valid_user_quant = (
            user_quant is None or user_quant == "gptq" or user_quant == "gptq_marlin_24"
        )

        if is_marlin_24_format and is_valid_user_quant:
            msg = "The model is serialized in {} format. Using {} kernel.".format(
                cls.get_name(), cls.get_name()
            )
            logger.info(msg)
            return cls.get_name()

        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["GPTQMarlin24LinearMethod"]:
        if isinstance(layer, LinearBase):
            return GPTQMarlin24LinearMethod(self)
        return None


class GPTQMarlin24LinearMethod(LinearMethodBase):
    """Linear method for Marlin24.

    Args:
        quant_config: The Marlin24 quantization config.
    """

    def __init__(self, quant_config: GPTQMarlin24Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del output_size  # Unused.
        weight_loader = extra_weight_attrs["weight_loader"]
        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}"
            )

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.min_n_threads != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"min_n_threads = {self.quant_config.min_n_threads}."
            )
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"pack_factor = {self.quant_config.pack_factor}."
            )

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_k_threads != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"min_k_threads = {self.quant_config.min_k_threads}."
            )
        if (
            self.quant_config.group_size != -1
            and input_size_per_partition % self.quant_config.group_size != 0
        ):
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"group_size = {self.quant_config.group_size}."
            )

        # Check that we have at least 4 tiles horizontally in the shard
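        # (With the defaults above: perm_len // tile_size**2 = 1024 // 256 = 4.)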
        num_tiles_per_perm = self.quant_config.perm_len // (
            self.quant_config.tile_size**2
        )
        if output_size_per_partition % num_tiles_per_perm != 0:
            raise ValueError("Each permutation group must reside on the same GPU")

        # Quantized weights packed into int32.
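        # (Layout sketch, as implied by the shape below: dim 0 walks K in
        # 16-row Marlin tiles, halved again because 2:4 sparsity keeps only
        # half of the values along K; dim 1 walks N with tile_size columns
        # per tile and pack_factor quantized values per int32 word.)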
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.tile_size // 2,
                output_size_per_partition
                * self.quant_config.tile_size
                // self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            marlin_tile_size=self.quant_config.tile_size,
            weight_loader=weight_loader,
        )

        # 2:4 sparsity metadata.
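        # (Assumed encoding: int16 entries recording which 2 of every 4
        # weights along K were kept; the K // 8 // 2 // 2 by N * 2 shape and
        # marlin_tile_size=2 follow the sparse Marlin kernel's expected
        # layout.)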
        meta = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // 8 // 2 // 2,
                output_size_per_partition * 2,
                device="cuda",
                dtype=torch.int16,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=1,
            marlin_tile_size=2,
            weight_loader=weight_loader,
        )

        # Determine if channelwise or not
        input_groups = (
            1
            if self.quant_config.group_size == -1
            else input_size_per_partition // self.quant_config.group_size
        )

        weight_scale_args = {
            "data": torch.empty(
                input_groups,
                output_size_per_partition,
                device="cuda",
                dtype=params_dtype,
            ),
            "weight_loader": weight_loader,
        }
        if input_groups == 1:
            scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
        else:
            scales = GroupQuantScaleParameter(
                output_dim=1, input_dim=0, **weight_scale_args
            )

        # Allocate workspace (Used for internal locking mechanism)
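        # (Sizing: one synchronization slot per block of min_n_threads output
        # columns, times max_parallel concurrent problem instances; the kernel
        # is assumed to use these zero-initialized counters to coordinate
        # partial-result reduction.)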
        max_workspace_size = (
            output_size_per_partition // self.quant_config.min_n_threads
        ) * self.quant_config.max_parallel

        workspace = BasevLLMParameter(
            data=torch.zeros(max_workspace_size, device="cuda", dtype=torch.int),
            weight_loader=weight_loader,
        )

        layer.register_parameter("B_24", qweight)
        layer.register_parameter("B_meta", meta)
        layer.register_parameter("s", scales)
        layer.register_parameter("workspace", workspace)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # required by torch.compile
        layer.B_24 = Parameter(layer.B_24.data, requires_grad=False)
        layer.s = Parameter(layer.s.data, requires_grad=False)
        layer.B_meta = Parameter(layer.B_meta.data, requires_grad=False)
        layer.workspace = Parameter(layer.workspace.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        qweight = layer.B_24
        meta = layer.B_meta
        scales = layer.s
        workspace = layer.workspace

        x_2d = x.view(-1, x.shape[-1])

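        # GEMM problem sizes: M = flattened tokens, K = input features,
        # N = output features (taken from the scales' second dimension).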
        size_m = x_2d.shape[0]
        size_k = x_2d.shape[1]
        size_n = scales.shape[1]

        output_2d = ops.gptq_marlin_24_gemm(
            x_2d,
            qweight,
            meta,
            scales,
            workspace,
            self.quant_config.quant_type,
            size_m,
            size_n,
            size_k,
        )

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output