init
This commit is contained in:
158
model_executor/layers/quantization/__init__.py
Normal file
158
model_executor/layers/quantization/__init__.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Literal, get_args
|
||||
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
|
||||
QuantizationMethods = Literal[
|
||||
"aqlm",
|
||||
"awq",
|
||||
"gptq",
|
||||
"deepspeedfp",
|
||||
"tpu_int8",
|
||||
"fp8",
|
||||
"ptpc_fp8",
|
||||
"fbgemm_fp8",
|
||||
"modelopt",
|
||||
"modelopt_fp4",
|
||||
"marlin",
|
||||
"bitblas",
|
||||
"gguf",
|
||||
"gptq_marlin_24",
|
||||
"gptq_marlin",
|
||||
"gptq_bitblas",
|
||||
"awq_marlin",
|
||||
# "gptq",
|
||||
"compressed-tensors",
|
||||
"bitsandbytes",
|
||||
"qqq",
|
||||
"hqq",
|
||||
"experts_int8",
|
||||
"neuron_quant",
|
||||
"ipex",
|
||||
"quark",
|
||||
"moe_wna16",
|
||||
"torchao",
|
||||
"auto-round",
|
||||
]
|
||||
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
|
||||
|
||||
# The customized quantization methods which will be added to this dict.
|
||||
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {}
|
||||
|
||||
|
||||
def register_quantization_config(quantization: str):
|
||||
"""Register a customized vllm quantization config.
|
||||
|
||||
When a quantization method is not supported by vllm, you can register a customized
|
||||
quantization config to support it.
|
||||
|
||||
Args:
|
||||
quantization (str): The quantization method name.
|
||||
|
||||
Examples:
|
||||
>>> from vllm.model_executor.layers.quantization import register_quantization_config
|
||||
>>> from vllm.model_executor.layers.quantization import get_quantization_config
|
||||
>>> from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
>>>
|
||||
>>> @register_quantization_config("my_quant")
|
||||
... class MyQuantConfig(QuantizationConfig):
|
||||
... pass
|
||||
>>>
|
||||
>>> get_quantization_config("my_quant")
|
||||
<class 'MyQuantConfig'>
|
||||
""" # noqa: E501
|
||||
|
||||
def _wrapper(quant_config_cls):
|
||||
if quantization in QUANTIZATION_METHODS:
|
||||
raise ValueError(
|
||||
f"The quantization method `{quantization}` is already exists.")
|
||||
if not issubclass(quant_config_cls, QuantizationConfig):
|
||||
raise ValueError("The quantization config must be a subclass of "
|
||||
"`QuantizationConfig`.")
|
||||
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG[quantization] = quant_config_cls
|
||||
QUANTIZATION_METHODS.append(quantization)
|
||||
return quant_config_cls
|
||||
|
||||
return _wrapper
|
||||
|
||||
|
||||
def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
if quantization not in QUANTIZATION_METHODS:
|
||||
raise ValueError(f"Invalid quantization method: {quantization}")
|
||||
|
||||
# lazy import to avoid triggering `torch.compile` too early
|
||||
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
|
||||
|
||||
from .aqlm import AQLMConfig
|
||||
from .auto_round import AutoRoundConfig
|
||||
from .awq import AWQConfig
|
||||
from .awq_marlin import AWQMarlinConfig
|
||||
from .bitblas import BitBLASConfig
|
||||
from .bitsandbytes import BitsAndBytesConfig
|
||||
from .compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensorsConfig)
|
||||
from .deepspeedfp import DeepSpeedFPConfig
|
||||
from .experts_int8 import ExpertsInt8Config
|
||||
from .fbgemm_fp8 import FBGEMMFp8Config
|
||||
from .fp8 import Fp8Config
|
||||
from .gguf import GGUFConfig
|
||||
from .gptq import GPTQConfig
|
||||
from .gptq_bitblas import GPTQBitBLASConfig
|
||||
from .gptq_marlin import GPTQMarlinConfig
|
||||
from .gptq_marlin_24 import GPTQMarlin24Config
|
||||
from .hqq_marlin import HQQMarlinConfig
|
||||
from .ipex_quant import IPEXConfig
|
||||
from .marlin import MarlinConfig
|
||||
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
|
||||
from .moe_wna16 import MoeWNA16Config
|
||||
from .neuron_quant import NeuronQuantConfig
|
||||
from .ptpc_fp8 import PTPCFp8Config
|
||||
from .qqq import QQQConfig
|
||||
from .torchao import TorchAOConfig
|
||||
from .tpu_int8 import Int8TpuConfig
|
||||
|
||||
method_to_config: dict[str, type[QuantizationConfig]] = {
|
||||
"aqlm": AQLMConfig,
|
||||
"awq": AWQConfig,
|
||||
"deepspeedfp": DeepSpeedFPConfig,
|
||||
"tpu_int8": Int8TpuConfig,
|
||||
"fp8": Fp8Config,
|
||||
"fbgemm_fp8": FBGEMMFp8Config,
|
||||
"modelopt": ModelOptFp8Config,
|
||||
"modelopt_fp4": ModelOptNvFp4Config,
|
||||
"marlin": MarlinConfig,
|
||||
"bitblas": BitBLASConfig,
|
||||
"gguf": GGUFConfig,
|
||||
"gptq_marlin_24": GPTQMarlin24Config,
|
||||
"gptq_marlin": GPTQConfig,
|
||||
"gptq_bitblas": GPTQBitBLASConfig,
|
||||
"awq_marlin": AWQConfig,
|
||||
"gptq": GPTQConfig,
|
||||
"compressed-tensors": CompressedTensorsConfig,
|
||||
"bitsandbytes": BitsAndBytesConfig,
|
||||
"ptpc_fp8": PTPCFp8Config,
|
||||
"qqq": QQQConfig,
|
||||
"hqq": HQQMarlinConfig,
|
||||
"experts_int8": ExpertsInt8Config,
|
||||
"neuron_quant": NeuronQuantConfig,
|
||||
"ipex": IPEXConfig,
|
||||
"quark": QuarkConfig,
|
||||
"moe_wna16": MoeWNA16Config,
|
||||
"torchao": TorchAOConfig,
|
||||
"auto-round": AutoRoundConfig,
|
||||
}
|
||||
# Update the `method_to_config` with customized quantization methods.
|
||||
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
|
||||
|
||||
return method_to_config[quantization]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"QuantizationConfig",
|
||||
"QuantizationMethods",
|
||||
"get_quantization_config",
|
||||
"QUANTIZATION_METHODS",
|
||||
]
|
||||
376
model_executor/layers/quantization/aqlm.py
Normal file
376
model_executor/layers/quantization/aqlm.py
Normal file
@@ -0,0 +1,376 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Supports AQLM compression, see https://github.com/Vahe1994/AQLM
|
||||
# and https://arxiv.org/pdf/2401.06118.pdf
|
||||
|
||||
import math
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
|
||||
def get_int_dtype(nbits: int) -> torch.dtype:
|
||||
if nbits <= 8:
|
||||
return torch.int8
|
||||
if nbits <= 16:
|
||||
return torch.int16
|
||||
if nbits <= 32:
|
||||
return torch.int32
|
||||
if nbits <= 64:
|
||||
return torch.int64
|
||||
raise ValueError(f"No dtype available for {nbits}-bit codebooks")
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
|
||||
return data.to(torch.int64) % (2**nbits)
|
||||
|
||||
|
||||
def dequantize_weight(codes: torch.Tensor,
|
||||
codebooks: torch.Tensor,
|
||||
scales: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""
|
||||
Decode float weights from quantization codes. Differentiable.
|
||||
:param codes: tensor of integer quantization codes, shape
|
||||
[*dims, num_out_groups, num_in_groups, num_codebooks]
|
||||
:param codebooks: tensor of vectors for each quantization code,
|
||||
[num_codebooks, codebook_size, out_group_size, in_group_size]
|
||||
:param scales: weight will be multiplied by this factor, must be
|
||||
broadcastble with
|
||||
[*dims, out_groups, num_in_groups, out_group_size, in_group_size]
|
||||
:return: reconstructed weight tensor of shape
|
||||
[*dims, num_in_groups*group_size]
|
||||
"""
|
||||
num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
|
||||
num_codebooks, codebook_size, out_group_size, in_group_size = \
|
||||
codebooks.shape
|
||||
out_features = num_out_groups * out_group_size
|
||||
in_features = num_in_groups * in_group_size
|
||||
codebook_offsets = torch.arange(
|
||||
0, num_codebooks * codebook_size, codebook_size,
|
||||
device=codes.device) # shape: [num_codebooks]
|
||||
reconstructed_weight_flat = F.embedding_bag(
|
||||
codes.flatten(0, -2) + codebook_offsets,
|
||||
codebooks.flatten(0, 1).flatten(-2, -1),
|
||||
mode="sum"
|
||||
) # [prod(dims) * num_out_groups * num_in_groups, out_group_size
|
||||
# * in_group_size]
|
||||
|
||||
reconstructed_weight_groupwise = reconstructed_weight_flat.view(
|
||||
list(codes.shape[:-3]) +
|
||||
[num_out_groups, num_in_groups, out_group_size, in_group_size])
|
||||
if scales is not None:
|
||||
reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(
|
||||
scales)
|
||||
return reconstructed_weight_groupwise.swapaxes(
|
||||
-3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features])
|
||||
|
||||
|
||||
def dequantize_gemm(
|
||||
input: torch.Tensor, # [..., in_features]
|
||||
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
|
||||
codebooks: torch.
|
||||
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
|
||||
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
dequantized_weight = dequantize_weight(
|
||||
unpack_int_data(codes, codebooks.shape[1].bit_length() - 1),
|
||||
codebooks,
|
||||
scales,
|
||||
)
|
||||
return F.linear(input, dequantized_weight, bias)
|
||||
|
||||
|
||||
# Generic dequantization, slow but flexible.
|
||||
def generic_dequantize_gemm(
|
||||
input: torch.Tensor, # [..., in_features]
|
||||
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
|
||||
codebooks: torch.
|
||||
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
|
||||
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
|
||||
output_partition_sizes: list[int],
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
output_shape = input.shape[:-1] + (scales.shape[0], )
|
||||
output = torch.empty(output_shape, dtype=input.dtype, device=input.device)
|
||||
num_outputs = len(output_partition_sizes)
|
||||
|
||||
# break the inputs and codebooks apart then combine the outputs.
|
||||
# Surprisingly (to me) this is faster than doing 3 de-quants and 1 big
|
||||
# multiply at the end.
|
||||
num_codebooks = codebooks.shape[0] // num_outputs
|
||||
assert (scales.shape[0] == codes.shape[0])
|
||||
assert (sum(output_partition_sizes) == scales.shape[0])
|
||||
output_offset = 0
|
||||
codebooks_offset = 0
|
||||
for output_size in output_partition_sizes:
|
||||
shard_output = dequantize_gemm(
|
||||
input, codes.narrow(0, output_offset, output_size),
|
||||
codebooks.narrow(0, codebooks_offset, num_codebooks),
|
||||
scales.narrow(0, output_offset, output_size), None
|
||||
if bias is None else bias.narrow(0, output_offset, output_size))
|
||||
|
||||
output_slice = output.narrow(-1, output_offset, output_size)
|
||||
assert (output_slice.shape == shard_output.shape)
|
||||
output_slice.copy_(shard_output)
|
||||
output_offset += output_size
|
||||
codebooks_offset += num_codebooks
|
||||
return output
|
||||
|
||||
|
||||
# Optimized dequnantize/decompression kernels, supports 1x16 and 2x8
|
||||
# at 6 and 9 times faster than the generic version above, respectively.
|
||||
def optimized_dequantize_gemm(
|
||||
input: torch.Tensor, # [..., in_features]
|
||||
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
|
||||
codebooks: torch.
|
||||
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
|
||||
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
|
||||
output_partition_sizes: list[int],
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
|
||||
|
||||
if bias is None:
|
||||
# scaling the output is fastest, so we do that when possible.
|
||||
output = F.linear(input, weights, bias)
|
||||
orig_shape = output.shape
|
||||
flattened_output = output.view(-1, output.size(-1))
|
||||
f_scales = scales.view(-1, scales.shape[0])
|
||||
b_scales = f_scales.expand(flattened_output.shape[0], -1)
|
||||
flattened_output *= b_scales
|
||||
return output.view(orig_shape)
|
||||
else:
|
||||
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
|
||||
-1, weights.shape[1])
|
||||
weights *= b_scales
|
||||
return F.linear(input, weights, bias)
|
||||
|
||||
|
||||
class AQLMConfig(QuantizationConfig):
|
||||
"""Config class for AQLM.
|
||||
|
||||
Reference: https://github.com/Vahe1994/AQLM
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_group_size: int,
|
||||
nbits_per_codebook: int,
|
||||
num_codebooks: int,
|
||||
out_group_size: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_group_size = in_group_size
|
||||
self.nbits_per_codebook = nbits_per_codebook
|
||||
self.num_codebooks = num_codebooks
|
||||
self.out_group_size = out_group_size
|
||||
|
||||
# out_group_size > 1 is untested, and probably won't work as-is.
|
||||
assert (self.out_group_size == 1)
|
||||
self.pack_factor = (self.in_group_size * self.out_group_size)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"AQLMConfig(in_group_size={self.in_group_size}, "
|
||||
f"nbits_per_codebook={self.nbits_per_codebook}, "
|
||||
f"num_codebooks={self.num_codebooks}, "
|
||||
f"out_group_size={self.out_group_size})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "aqlm"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 60
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return [] # no extra configs.
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "AQLMConfig":
|
||||
in_group_size = cls.get_from_keys(config, ["in_group_size"])
|
||||
nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"])
|
||||
num_code_books = cls.get_from_keys(config, ["num_codebooks"])
|
||||
out_group_size = cls.get_from_keys(config, ["out_group_size"])
|
||||
return cls(in_group_size, nbits_per_codebook, num_code_books,
|
||||
out_group_size)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["AQLMLinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return AQLMLinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class AQLMLinearMethod(LinearMethodBase):
|
||||
"""Linear method for AQLM.
|
||||
|
||||
Args:
|
||||
quant_config: The AQLM quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: AQLMConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
del output_size # Unused.
|
||||
del input_size # Unused.
|
||||
|
||||
if params_dtype != torch.half:
|
||||
raise ValueError("Only half is currently supported by aqlm")
|
||||
if input_size_per_partition % self.quant_config.in_group_size != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if output_size_per_partition % self.quant_config.out_group_size != 0:
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
codes = Parameter(
|
||||
torch.empty(
|
||||
# There could actually be two pack factors, one along input and
|
||||
# one along output, but we don't currently support
|
||||
# out_group_size, and only the one along output needs to be
|
||||
# marked with "packed_dim" in order for QKVLinear to work.
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
self.quant_config.num_codebooks,
|
||||
dtype=get_int_dtype(self.quant_config.nbits_per_codebook),
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
set_weight_attrs(
|
||||
codes,
|
||||
{
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
"packed_dim": 1,
|
||||
"pack_factor": self.quant_config.pack_factor,
|
||||
},
|
||||
)
|
||||
|
||||
codebooks = Parameter(
|
||||
torch.empty(
|
||||
self.quant_config.num_codebooks * len(output_partition_sizes),
|
||||
2**self.quant_config.nbits_per_codebook,
|
||||
self.quant_config.out_group_size,
|
||||
self.quant_config.in_group_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
codebooks,
|
||||
{
|
||||
# metadata indicates fixed size concatenated along dim 0
|
||||
"is_metadata": True,
|
||||
"output_partition_sizes": output_partition_sizes
|
||||
},
|
||||
)
|
||||
|
||||
scales = Parameter(
|
||||
torch.empty(
|
||||
(
|
||||
output_size_per_partition //
|
||||
self.quant_config.out_group_size,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
scales,
|
||||
{
|
||||
"output_dim": 0,
|
||||
"packed_dim": 0,
|
||||
"pack_factor": self.quant_config.out_group_size
|
||||
},
|
||||
)
|
||||
|
||||
layer.register_parameter("codes", codes)
|
||||
set_weight_attrs(codes, extra_weight_attrs)
|
||||
layer.register_parameter("codebooks", codebooks)
|
||||
set_weight_attrs(codebooks, extra_weight_attrs)
|
||||
layer.register_parameter("scales", scales)
|
||||
set_weight_attrs(scales, extra_weight_attrs)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
codebooks = layer.codebooks
|
||||
codes = layer.codes
|
||||
scales = layer.scales
|
||||
output_partition_sizes = getattr(codebooks, "output_partition_sizes",
|
||||
[])
|
||||
|
||||
nbooks = codes.shape[2]
|
||||
ingroups = codebooks.shape[3]
|
||||
outgroups = codebooks.shape[2]
|
||||
bits = codebooks.shape[1]
|
||||
|
||||
# We support these formats with dedicated gemm and decompression
|
||||
# kernels.
|
||||
if ingroups == 8 and outgroups == 1 and (
|
||||
(bits == 256 and nbooks == 2) or (bits == 65536 and nbooks == 1)):
|
||||
|
||||
# thresholds determined by timings on an A6000, one GPU
|
||||
use_gemv = math.prod(x.shape[:-1]) <= 6
|
||||
|
||||
return ops.aqlm_gemm(
|
||||
x,
|
||||
codes,
|
||||
codebooks,
|
||||
scales,
|
||||
output_partition_sizes,
|
||||
bias,
|
||||
) if use_gemv else optimized_dequantize_gemm(
|
||||
x,
|
||||
codes,
|
||||
codebooks,
|
||||
scales,
|
||||
output_partition_sizes,
|
||||
bias,
|
||||
)
|
||||
|
||||
# fall back all unoptimized formats
|
||||
return generic_dequantize_gemm(
|
||||
x,
|
||||
codes,
|
||||
codebooks,
|
||||
scales,
|
||||
output_partition_sizes,
|
||||
bias,
|
||||
)
|
||||
310
model_executor/layers/quantization/auto_round.py
Normal file
310
model_executor/layers/quantization/auto_round.py
Normal file
@@ -0,0 +1,310 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from fractions import Fraction
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AutoRoundConfig(QuantizationConfig):
|
||||
"""Config class for AutoRound.
|
||||
Reference: https://arxiv.org/pdf/2309.05516
|
||||
"""
|
||||
|
||||
SUPPORTED_BITS = {2, 3, 4, 8}
|
||||
SUPPORTED_DTYPES = {"int"}
|
||||
SUPPORTED_FORMATS = {"auto_round:auto_gptq", "auto_round:auto_awq"}
|
||||
SUPPORTED_BACKENDS = {
|
||||
"auto", "gptq", "gptq:marlin", "awq", "awq:marlin", "marlin", "ipex"
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
sym: bool = True,
|
||||
packing_format: str = "auto_round:auto_gptq",
|
||||
block_name_to_quantize: Optional[Union[str, list[str]]] = None,
|
||||
extra_config: Optional[dict[str, Any]] = None,
|
||||
data_type: str = "int",
|
||||
backend: str = "auto",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if weight_bits not in self.SUPPORTED_BITS:
|
||||
raise ValueError(f"Unsupported weight_bits: {weight_bits}, "
|
||||
f"currently only support {self.SUPPORTED_BITS}")
|
||||
if data_type not in self.SUPPORTED_DTYPES:
|
||||
raise ValueError(
|
||||
f"Unsupported data_type: {data_type},"
|
||||
f" currently only support {self.SUPPORTED_DTYPES}")
|
||||
if packing_format not in self.SUPPORTED_FORMATS:
|
||||
raise ValueError(
|
||||
f"Unsupported packing_format: {packing_format}, "
|
||||
f"currently only support {self.SUPPORTED_FORMATS}")
|
||||
if backend not in self.SUPPORTED_BACKENDS:
|
||||
raise ValueError(
|
||||
f"Unsupported backend: {backend}, "
|
||||
f"currently only support {self.SUPPORTED_BACKENDS}")
|
||||
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.sym = sym
|
||||
self.packing_format = packing_format
|
||||
self.block_name_to_quantize = (block_name_to_quantize.split(",") if
|
||||
isinstance(block_name_to_quantize, str)
|
||||
else block_name_to_quantize)
|
||||
self.extra_config = extra_config
|
||||
self.data_type = data_type
|
||||
self.backend = backend
|
||||
self.pack_factor = Fraction(32, weight_bits)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"AutoRoundConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, sym={self.sym})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "auto-round"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 60
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantization_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig":
|
||||
return cls(
|
||||
weight_bits=cls.get_from_keys(config, ["bits"]),
|
||||
group_size=cls.get_from_keys(config, ["group_size"]),
|
||||
sym=cls.get_from_keys(config, ["sym"]),
|
||||
packing_format=cls.get_from_keys_or(config, ["packing_format"],
|
||||
"auto_round:auto_gptq"),
|
||||
block_name_to_quantize=cls.get_from_keys_or(
|
||||
config, ["block_name_to_quantize", "to_quant_block_names"],
|
||||
None),
|
||||
extra_config=cls.get_from_keys_or(config, ["extra_config"], None),
|
||||
data_type=cls.get_from_keys_or(config, ["data_type"], "int"),
|
||||
backend=cls.get_from_keys_or(config, ["backend", "vllm_backend"],
|
||||
"auto"),
|
||||
)
|
||||
|
||||
def get_layer_config(self, layer, layer_name: str):
|
||||
# Priority: extra_config > block_name_to_quantize > type fallback
|
||||
if self.extra_config and layer_name in self.extra_config:
|
||||
cfg = self.extra_config[layer_name]
|
||||
return cfg.get("bits", self.weight_bits), cfg.get(
|
||||
"group_size", self.group_size), cfg.get("sym", self.sym)
|
||||
|
||||
quantized = True
|
||||
if self.block_name_to_quantize:
|
||||
quantized = any(
|
||||
layer_name.startswith(name)
|
||||
for name in self.block_name_to_quantize)
|
||||
elif isinstance(layer, ParallelLMHead):
|
||||
quantized = False
|
||||
|
||||
return (self.weight_bits, self.group_size,
|
||||
self.sym) if quantized else (16, -1, True)
|
||||
|
||||
def check_quantized(self, weight_bits: int) -> bool:
|
||||
return weight_bits < 16
|
||||
|
||||
def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported, check_moe_marlin_supports_layer)
|
||||
|
||||
weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
|
||||
if not self.check_quantized(weight_bits):
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
return UnquantizedLinearMethod()
|
||||
else:
|
||||
return None
|
||||
|
||||
logger.debug("[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
|
||||
prefix, layer.__class__.__name__, weight_bits, group_size,
|
||||
sym)
|
||||
if backend == "auto" or "marlin" in backend:
|
||||
AWQ_TYPE_MAP = {
|
||||
4: scalar_types.uint4,
|
||||
8: scalar_types.uint8,
|
||||
}
|
||||
use_marlin = (weight_bits
|
||||
in AWQ_TYPE_MAP) and check_marlin_supported(
|
||||
AWQ_TYPE_MAP[weight_bits], group_size, not sym)
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
use_marlin = use_marlin and check_moe_marlin_supports_layer(
|
||||
layer, group_size)
|
||||
|
||||
else:
|
||||
use_marlin = False
|
||||
if use_marlin:
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import (
|
||||
AWQMarlinConfig, AWQMarlinLinearMethod, AWQMoEMethod)
|
||||
quant_args_marlin = AWQMarlinConfig(weight_bits=weight_bits,
|
||||
group_size=group_size,
|
||||
zero_point=not sym,
|
||||
lm_head_quantized=False,
|
||||
full_config={},
|
||||
modules_to_not_convert=[])
|
||||
else:
|
||||
from vllm.model_executor.layers.quantization.awq import (
|
||||
AWQConfig, AWQLinearMethod)
|
||||
quant_args = AWQConfig(
|
||||
weight_bits=weight_bits,
|
||||
group_size=group_size,
|
||||
zero_point=not sym,
|
||||
)
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
if use_marlin:
|
||||
return AWQMoEMethod(quant_args_marlin)
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import (
|
||||
MoeWNA16Config)
|
||||
config = {
|
||||
"quant_method": "awq",
|
||||
"bits": weight_bits,
|
||||
"group_size": group_size,
|
||||
"zero_point": not sym,
|
||||
"lm_head": False,
|
||||
}
|
||||
return MoeWNA16Config.from_config(config).get_quant_method(
|
||||
layer, prefix)
|
||||
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
if use_marlin:
|
||||
return AWQMarlinLinearMethod(quant_args_marlin)
|
||||
else:
|
||||
return AWQLinearMethod(quant_args)
|
||||
return None
|
||||
|
||||
def apply_gptq_quant_layer(self,
|
||||
layer,
|
||||
prefix: str,
|
||||
backend: str = "auto"):
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported, check_moe_marlin_supports_layer)
|
||||
weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
|
||||
if not self.check_quantized(weight_bits):
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
return UnquantizedLinearMethod()
|
||||
else:
|
||||
return None
|
||||
|
||||
logger.debug("[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
|
||||
prefix, layer.__class__.__name__, weight_bits, group_size,
|
||||
sym)
|
||||
if backend == "auto" or "marlin" in backend:
|
||||
GPTQ_TYPE_MAP = {
|
||||
(4, True): scalar_types.uint4b8,
|
||||
(8, True): scalar_types.uint8b128,
|
||||
}
|
||||
use_marlin = ((weight_bits, sym) in GPTQ_TYPE_MAP
|
||||
and check_marlin_supported(
|
||||
GPTQ_TYPE_MAP[(weight_bits, sym)],
|
||||
group_size,
|
||||
has_zp=not sym))
|
||||
if isinstance(layer, FusedMoE):
|
||||
use_marlin = use_marlin and check_moe_marlin_supports_layer(
|
||||
layer, group_size)
|
||||
else:
|
||||
use_marlin = False
|
||||
if use_marlin:
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinConfig, GPTQMarlinLinearMethod, GPTQMarlinMoEMethod)
|
||||
quant_args_marlin = GPTQMarlinConfig(weight_bits=weight_bits,
|
||||
group_size=group_size,
|
||||
is_sym=sym,
|
||||
lm_head_quantized=False,
|
||||
desc_act=False,
|
||||
dynamic={},
|
||||
full_config={})
|
||||
else:
|
||||
from vllm.model_executor.layers.quantization.gptq import (
|
||||
GPTQConfig, GPTQLinearMethod)
|
||||
quant_args = GPTQConfig(weight_bits=weight_bits,
|
||||
group_size=group_size,
|
||||
lm_head_quantized=False,
|
||||
desc_act=False,
|
||||
dynamic={})
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
if use_marlin:
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import (
|
||||
MoeWNA16Config)
|
||||
config = {
|
||||
"quant_method": "gptq",
|
||||
"bits": weight_bits,
|
||||
"group_size": group_size,
|
||||
"sym": sym,
|
||||
"lm_head": False,
|
||||
}
|
||||
return MoeWNA16Config.from_config(config).get_quant_method(
|
||||
layer, prefix)
|
||||
return GPTQMarlinMoEMethod(quant_args_marlin)
|
||||
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
if use_marlin:
|
||||
return GPTQMarlinLinearMethod(quant_args_marlin)
|
||||
else:
|
||||
return GPTQLinearMethod(quant_args)
|
||||
|
||||
return None
|
||||
|
||||
def apply_ipex_quant_layer(self, layer, prefix: str):
|
||||
weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
|
||||
if not self.check_quantized(weight_bits):
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
return UnquantizedLinearMethod()
|
||||
else:
|
||||
return None
|
||||
from vllm.model_executor.layers.quantization.ipex_quant import (
|
||||
IPEXAWQLinearMethod, IPEXConfig, IPEXGPTQLinearMethod)
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
if "awq" in self.packing_format:
|
||||
config = IPEXConfig(method="awq",
|
||||
weight_bits=weight_bits,
|
||||
group_size=group_size)
|
||||
return IPEXAWQLinearMethod(config)
|
||||
elif "gptq" in self.packing_format:
|
||||
config = IPEXConfig(method="gptq",
|
||||
weight_bits=weight_bits,
|
||||
group_size=group_size)
|
||||
return IPEXGPTQLinearMethod(config)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"ipex backend only supports awq "
|
||||
f"and gtpq format,but got {self.packing_format}")
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module, prefix: str):
|
||||
if (current_platform.is_cpu() or current_platform.is_xpu()
|
||||
or self.backend == "ipex"):
|
||||
return self.apply_ipex_quant_layer(layer, prefix)
|
||||
if "gptq" in self.packing_format or "gptq" in self.backend:
|
||||
return self.apply_gptq_quant_layer(layer, prefix)
|
||||
if "awq" in self.packing_format or "awq" in self.backend:
|
||||
return self.apply_awq_quant_layer(layer, prefix)
|
||||
241
model_executor/layers/quantization/awq.py
Normal file
241
model_executor/layers/quantization/awq.py
Normal file
@@ -0,0 +1,241 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
class AWQConfig(QuantizationConfig):
|
||||
"""Config class for AWQ.
|
||||
|
||||
Reference: https://arxiv.org/abs/2306.00978
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
zero_point: bool,
|
||||
modules_to_not_convert: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.zero_point = zero_point
|
||||
self.modules_to_not_convert = modules_to_not_convert or []
|
||||
|
||||
if self.weight_bits != 4:
|
||||
raise ValueError(
|
||||
"Currently, only 4-bit weight quantization is supported for "
|
||||
f"AWQ, but got {self.weight_bits} bits.")
|
||||
self.pack_factor = 32 // self.weight_bits
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"AWQConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"zero_point={self.zero_point}, "
|
||||
f"modules_to_not_convert={self.modules_to_not_convert})")
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "awq"
|
||||
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# The AWQ kernel only supports Turing or newer GPUs.
|
||||
return 75
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
return [
|
||||
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
|
||||
# E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
|
||||
"quantize_config.json",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "AWQConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
|
||||
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
|
||||
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||
modules_to_not_convert = cls.get_from_keys_or(
|
||||
config, ["modules_to_not_convert"], None)
|
||||
return cls(weight_bits, group_size, zero_point, modules_to_not_convert)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["LinearMethodBase"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
|
||||
return UnquantizedLinearMethod()
|
||||
return AWQLinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: list[str]):
|
||||
return any(module_name in prefix for module_name in modules_to_not_convert)
|
||||
|
||||
def _apply_awq_fake(x: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
qzeros: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
pack_factor: int,
|
||||
group_size: int) -> torch.Tensor:
|
||||
out_shape = ()
|
||||
if group_size % 32:
|
||||
out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
|
||||
else:
|
||||
out_shape = (x.shape[:-1] + (qweight.shape[0], ))
|
||||
return torch.empty(out_shape, dtype=x.dtype, device=x.device)
|
||||
|
||||
def _apply_awq(x: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
qzeros: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
pack_factor: int,
|
||||
group_size: int) -> torch.Tensor:
|
||||
out_shape = ()
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out = torch.empty(0)
|
||||
# num_tokens >= threshold
|
||||
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
|
||||
# if (FP16_MATMUL_HEURISTIC_CONDITION and reshaped_x.dtype == torch.half) or self.quant_config.group_size != 128:
|
||||
if group_size % 32:
|
||||
out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
|
||||
out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
|
||||
out = torch.matmul(reshaped_x, out)
|
||||
else:
|
||||
num_out_channel = qweight.shape[0]
|
||||
out_shape = (x.shape[:-1] + (num_out_channel, ))
|
||||
temp_space = torch.empty(0, dtype=torch.float32, device=x.device)
|
||||
if reshaped_x.dtype == torch.bfloat16:
|
||||
temp_space = torch.zeros(reshaped_x.shape[0], num_out_channel,
|
||||
dtype=torch.float32, device=x.device)
|
||||
out = ops.awq_gemm(reshaped_x, qweight, qzeros, scales,
|
||||
pack_factor, temp_space,
|
||||
True if reshaped_x.dtype == torch.bfloat16 else False)
|
||||
if bias is not None:
|
||||
out.add_(bias)
|
||||
return out.reshape(out_shape)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="_apply_awq",
|
||||
op_func=_apply_awq,
|
||||
mutates_args=[],
|
||||
fake_impl=_apply_awq_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||
)
|
||||
|
||||
class AWQLinearMethod(LinearMethodBase):
|
||||
"""Linear method for AWQ.
|
||||
|
||||
Args:
|
||||
quant_config: The AWQ quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: AWQConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
# Normalize group_size
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
|
||||
if input_size_per_partition % group_size != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if output_size_per_partition % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
num_groups = input_size_per_partition // group_size
|
||||
|
||||
qzeros = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
num_groups,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
scales = GroupQuantScaleParameter(data=torch.empty(
|
||||
num_groups,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.qweight = torch.nn.Parameter(layer.qweight.data,
|
||||
requires_grad=False)
|
||||
layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
|
||||
requires_grad=False)
|
||||
layer.scales = torch.nn.Parameter(layer.scales.data,
|
||||
requires_grad=False)
|
||||
# warmup
|
||||
if self.quant_config.group_size % 32:
|
||||
pass
|
||||
else:
|
||||
qweight = ops.awq_to_gptq_4bit(layer.qweight)
|
||||
layer.qweight = torch.nn.Parameter(qweight, requires_grad=False)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
qweight = layer.qweight
|
||||
scales = layer.scales
|
||||
qzeros = layer.qzeros
|
||||
pack_factor = self.quant_config.pack_factor
|
||||
group_size = self.quant_config.group_size
|
||||
|
||||
return torch.ops.vllm._apply_awq(x, qweight, scales, qzeros,
|
||||
bias, pack_factor, group_size)
|
||||
519
model_executor/layers/quantization/awq_marlin.py
Normal file
519
model_executor/layers/quantization/awq_marlin.py
Normal file
@@ -0,0 +1,519 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.awq import (AWQConfig,
|
||||
is_layer_skipped_awq)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
|
||||
check_marlin_supports_layer, check_moe_marlin_supports_layer,
|
||||
marlin_make_empty_g_idx, marlin_make_workspace_new,
|
||||
marlin_moe_permute_scales, marlin_permute_scales,
|
||||
moe_awq_to_marlin_zero_points, verify_marlin_supported,
|
||||
verify_marlin_supports_shape)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AWQMarlinConfig(QuantizationConfig):
|
||||
"""Config class for AWQ Marlin"""
|
||||
|
||||
# num_bits -> type
|
||||
TYPE_MAP = {
|
||||
4: scalar_types.uint4,
|
||||
8: scalar_types.uint8,
|
||||
}
|
||||
|
||||
def __init__(self, weight_bits: int, group_size: int, zero_point: bool,
|
||||
lm_head_quantized: bool,
|
||||
modules_to_not_convert: Optional[list[str]],
|
||||
full_config: dict[str, Any]) -> None:
|
||||
super().__init__()
|
||||
self.pack_factor = 32 // weight_bits # packed into int32
|
||||
self.group_size = group_size
|
||||
self.zero_point = zero_point
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.weight_bits = weight_bits
|
||||
self.modules_to_not_convert = modules_to_not_convert or []
|
||||
self.full_config = full_config
|
||||
|
||||
if self.weight_bits not in self.TYPE_MAP:
|
||||
raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
|
||||
f"Supported num_bits = {self.TYPE_MAP.keys()}")
|
||||
|
||||
self.quant_type = self.TYPE_MAP[self.weight_bits]
|
||||
|
||||
verify_marlin_supported(self.quant_type,
|
||||
group_size=self.group_size,
|
||||
has_zp=self.zero_point)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"AWQMarlinConfig(quant_type={self.quant_type}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"zero_point={self.zero_point}, "
|
||||
f"lm_head_quantized={self.lm_head_quantized}, "
|
||||
f"modules_to_not_convert={self.modules_to_not_convert})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "awq_marlin"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "AWQMarlinConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
modules_to_not_convert = cls.get_from_keys_or(
|
||||
config, ["modules_to_not_convert"], None)
|
||||
return cls(weight_bits, group_size, zero_point, lm_head_quantized,
|
||||
modules_to_not_convert, config)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "marlin"
|
||||
or user_quant == "awq_marlin")
|
||||
|
||||
if can_convert and is_valid_user_quant:
|
||||
msg = ("The model is convertible to {} during runtime."
|
||||
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
if can_convert and user_quant == "awq":
|
||||
logger.info("Detected that the model can run with awq_marlin"
|
||||
", however you specified quantization=awq explicitly,"
|
||||
" so forcing awq. Use quantization=awq_marlin for"
|
||||
" faster inference")
|
||||
return None
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if (isinstance(layer, LinearBase) or
|
||||
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
|
||||
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
|
||||
return UnquantizedLinearMethod()
|
||||
# Check if the layer is supported by AWQMarlin.
|
||||
if not check_marlin_supports_layer(layer, self.group_size):
|
||||
logger.warning_once(
|
||||
"Layer '%s' is not supported by AWQMarlin. Falling back to unoptimized AWQ kernels.", # noqa: E501
|
||||
prefix,
|
||||
)
|
||||
return AWQConfig.from_config(
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
return AWQMarlinLinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import (
|
||||
MoeWNA16Config)
|
||||
if not check_moe_marlin_supports_layer(layer, self.group_size):
|
||||
logger.warning_once(
|
||||
f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
|
||||
"Falling back to Moe WNA16 kernels.")
|
||||
return MoeWNA16Config.from_config(
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
return AWQMoEMethod(self)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_awq_marlin_compatible(cls, quant_config: dict[str, Any]):
|
||||
# Extract data from quant config.
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits")
|
||||
group_size = quant_config.get("group_size")
|
||||
zero_point = quant_config.get("zero_point")
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
|
||||
if quant_method != "awq":
|
||||
return False
|
||||
|
||||
# If we cannot find the info needed in the config, cannot convert.
|
||||
if (num_bits is None or group_size is None or zero_point is None):
|
||||
return False
|
||||
|
||||
if num_bits not in cls.TYPE_MAP:
|
||||
return False
|
||||
|
||||
return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
|
||||
group_size=group_size,
|
||||
has_zp=zero_point)
|
||||
|
||||
|
||||
class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
"""Linear method for AWQ Marlin.
|
||||
|
||||
Args:
|
||||
quant_config: The AWQ Marlin quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: AWQMarlinConfig) -> None:
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
del output_size
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
|
||||
# Normalize group_size
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
|
||||
verify_marlin_supports_shape(
|
||||
output_size_per_partition=output_size_per_partition,
|
||||
input_size_per_partition=input_size_per_partition,
|
||||
input_size=input_size,
|
||||
group_size=group_size)
|
||||
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
num_groups = input_size_per_partition // group_size
|
||||
|
||||
qzeros = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
num_groups,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
scales = GroupQuantScaleParameter(data=torch.empty(
|
||||
num_groups,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.num_groups = num_groups
|
||||
|
||||
# TODO: Update this docs
|
||||
# Checkpoints are serialized in AutoAWQ format, which is different from the
|
||||
# marlin format. This function is called after the weights are loaded.
|
||||
# Here, we handle the repacking
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
device = layer.qweight.device
|
||||
layer.qweight = torch.nn.Parameter(layer.qweight.data,
|
||||
requires_grad=False)
|
||||
layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
|
||||
requires_grad=False)
|
||||
layer.scales = torch.nn.Parameter(layer.scales.data,
|
||||
requires_grad=False)
|
||||
|
||||
# Allocate marlin workspace
|
||||
layer.workspace = marlin_make_workspace_new(device)
|
||||
|
||||
# Repack weights from AWQ format to marlin format.
|
||||
marlin_qweight = ops.awq_marlin_repack(
|
||||
layer.qweight,
|
||||
size_k=layer.input_size_per_partition,
|
||||
size_n=layer.output_size_per_partition,
|
||||
num_bits=self.quant_config.quant_type.size_bits)
|
||||
replace_parameter(layer, "qweight", marlin_qweight)
|
||||
|
||||
# Permute scales from AWQ format to marlin format.
|
||||
marlin_scales = marlin_permute_scales(
|
||||
layer.scales,
|
||||
size_k=layer.input_size_per_partition,
|
||||
size_n=layer.output_size_per_partition,
|
||||
group_size=self.quant_config.group_size)
|
||||
replace_parameter(layer, "scales", marlin_scales)
|
||||
|
||||
# Permute zero-points from AWQ format to marlin format.
|
||||
marlin_zp = awq_to_marlin_zero_points(
|
||||
layer.qzeros,
|
||||
size_k=layer.num_groups,
|
||||
size_n=layer.output_size_per_partition,
|
||||
num_bits=self.quant_config.quant_type.size_bits)
|
||||
replace_parameter(layer, "qzeros", marlin_zp)
|
||||
|
||||
# Not-used
|
||||
layer.g_idx = marlin_make_empty_g_idx(device)
|
||||
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return apply_awq_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.qweight,
|
||||
weight_scale=layer.scales,
|
||||
weight_zp=layer.qzeros,
|
||||
g_idx=layer.g_idx,
|
||||
g_idx_sort_indices=layer.g_idx_sort_indices,
|
||||
workspace=layer.workspace,
|
||||
quant_type=self.quant_config.quant_type,
|
||||
output_size_per_partition=layer.output_size_per_partition,
|
||||
input_size_per_partition=layer.input_size_per_partition,
|
||||
bias=bias)
|
||||
|
||||
|
||||
class AWQMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(self, quant_config: AWQMarlinConfig):
|
||||
self.quant_config = quant_config
|
||||
if self.quant_config.weight_bits != 4:
|
||||
raise ValueError("AWQMoEMethod only supports 4bit now.")
|
||||
self.quant_type = scalar_types.uint4
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
extra_weight_attrs.update({
|
||||
"is_transposed":
|
||||
True,
|
||||
"quant_method":
|
||||
FusedMoeWeightScaleSupported.GROUP.value,
|
||||
})
|
||||
|
||||
w13_qweight = Parameter(
|
||||
torch.empty(num_experts,
|
||||
hidden_size,
|
||||
2 * intermediate_size_per_partition //
|
||||
self.quant_config.pack_factor,
|
||||
dtype=torch.int32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_qweight", w13_qweight)
|
||||
set_weight_attrs(w13_qweight, extra_weight_attrs)
|
||||
|
||||
w2_qweight = Parameter(torch.empty(num_experts,
|
||||
intermediate_size_per_partition,
|
||||
hidden_size //
|
||||
self.quant_config.pack_factor,
|
||||
dtype=torch.int32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_qweight", w2_qweight)
|
||||
set_weight_attrs(w2_qweight, extra_weight_attrs)
|
||||
|
||||
num_groups_w13 = hidden_size // self.quant_config.group_size
|
||||
num_groups_w2 = (intermediate_size_per_partition //
|
||||
self.quant_config.group_size)
|
||||
|
||||
# WEIGHT_SCALES
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
w13_scales = Parameter(torch.empty(num_experts,
|
||||
num_groups_w13,
|
||||
intermediate_size_per_partition * 2,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_scales", w13_scales)
|
||||
set_weight_attrs(w13_scales, extra_weight_attrs)
|
||||
|
||||
w2_scales = Parameter(torch.empty(num_experts,
|
||||
num_groups_w2,
|
||||
hidden_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_scales", w2_scales)
|
||||
set_weight_attrs(w2_scales, extra_weight_attrs)
|
||||
|
||||
# WEIGHT_ZERO_POINT
|
||||
# Allocate 2 zero points for w1 and w3 respectively.
|
||||
w13_qzeros = Parameter(
|
||||
torch.empty(num_experts,
|
||||
num_groups_w13,
|
||||
2 * intermediate_size_per_partition //
|
||||
self.quant_config.pack_factor,
|
||||
dtype=torch.int32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_qzeros", w13_qzeros)
|
||||
set_weight_attrs(w13_qzeros, extra_weight_attrs)
|
||||
|
||||
w2_qzeros = Parameter(torch.empty(num_experts,
|
||||
num_groups_w2,
|
||||
hidden_size //
|
||||
self.quant_config.pack_factor,
|
||||
dtype=torch.int32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_qzeros", w2_qzeros)
|
||||
set_weight_attrs(w2_qzeros, extra_weight_attrs)
|
||||
|
||||
device = layer.w13_qweight.device
|
||||
layer.workspace = marlin_make_workspace_new(device, 4)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
num_experts = layer.w13_qweight.shape[0]
|
||||
device = layer.w13_qweight.device
|
||||
|
||||
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
marlin_w13_qweight = ops.awq_marlin_moe_repack(
|
||||
layer.w13_qweight,
|
||||
layer.w13_g_idx_sort_indices,
|
||||
size_k=layer.w13_qweight.shape[1],
|
||||
size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
|
||||
num_bits=self.quant_config.weight_bits,
|
||||
)
|
||||
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
|
||||
|
||||
marlin_w2_qweight = ops.awq_marlin_moe_repack(
|
||||
layer.w2_qweight,
|
||||
layer.w2_g_idx_sort_indices,
|
||||
size_k=layer.w2_qweight.shape[1],
|
||||
size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
|
||||
num_bits=self.quant_config.weight_bits,
|
||||
)
|
||||
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
|
||||
|
||||
# Why does this take the intermediate size for size_k?
|
||||
marlin_w13_scales = marlin_moe_permute_scales(
|
||||
s=layer.w13_scales,
|
||||
size_k=layer.intermediate_size_per_partition,
|
||||
size_n=layer.w13_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
)
|
||||
|
||||
replace_parameter(layer, "w13_scales", marlin_w13_scales)
|
||||
|
||||
marlin_w2_scales = marlin_moe_permute_scales(
|
||||
s=layer.w2_scales,
|
||||
size_k=layer.intermediate_size_per_partition,
|
||||
size_n=layer.w2_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
)
|
||||
replace_parameter(layer, "w2_scales", marlin_w2_scales)
|
||||
|
||||
marlin_w13_zp = moe_awq_to_marlin_zero_points(
|
||||
layer.w13_qzeros,
|
||||
size_k=layer.w13_qzeros.shape[1],
|
||||
size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
|
||||
num_bits=self.quant_config.weight_bits)
|
||||
replace_parameter(layer, "w13_qzeros", marlin_w13_zp)
|
||||
|
||||
marlin_w2_zp = moe_awq_to_marlin_zero_points(
|
||||
layer.w2_qzeros,
|
||||
size_k=layer.w2_qzeros.shape[1],
|
||||
size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
|
||||
num_bits=self.quant_config.weight_bits)
|
||||
replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
raise NotImplementedError(
|
||||
"Apply router weight on input is not supported for"
|
||||
"fused Marlin MoE method.")
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
layer.w13_scales,
|
||||
layer.w2_scales,
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type_id=self.quant_type.id,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_zeros=layer.w13_qzeros,
|
||||
w2_zeros=layer.w2_qzeros,
|
||||
workspace=layer.workspace)
|
||||
320
model_executor/layers/quantization/awq_triton.py
Normal file
320
model_executor/layers/quantization/awq_triton.py
Normal file
@@ -0,0 +1,320 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
|
||||
|
||||
@triton.jit
|
||||
def awq_dequantize_kernel(
|
||||
qweight_ptr, # quantized matrix
|
||||
scales_ptr, # scales, per group
|
||||
zeros_ptr, # zeros, per group
|
||||
group_size, # Should always be one of the supported group sizes
|
||||
result_ptr, # Output matrix
|
||||
num_cols, # input num cols in qweight
|
||||
num_rows, # input num rows in qweight
|
||||
BLOCK_SIZE_X: tl.constexpr,
|
||||
BLOCK_SIZE_Y: tl.constexpr):
|
||||
# Setup the pids.
|
||||
pid_x = tl.program_id(axis=0)
|
||||
pid_y = tl.program_id(axis=1)
|
||||
|
||||
# Compute offsets and masks for qweight_ptr.
|
||||
offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
|
||||
offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
|
||||
offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]
|
||||
|
||||
masks_y = offsets_y < num_rows
|
||||
masks_x = offsets_x < num_cols
|
||||
|
||||
masks = masks_y[:, None] & masks_x[None, :]
|
||||
|
||||
# Compute offsets and masks for result output ptr.
|
||||
result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
|
||||
result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(
|
||||
0, BLOCK_SIZE_X * 8)
|
||||
result_offsets = (8 * num_cols * result_offsets_y[:, None] +
|
||||
result_offsets_x[None, :])
|
||||
|
||||
result_masks_y = result_offsets_y < num_rows
|
||||
result_masks_x = result_offsets_x < num_cols * 8
|
||||
result_masks = result_masks_y[:, None] & result_masks_x[None, :]
|
||||
|
||||
# Load the weights.
|
||||
iweights = tl.load(qweight_ptr + offsets, masks, 0.0)
|
||||
iweights = tl.interleave(iweights, iweights)
|
||||
iweights = tl.interleave(iweights, iweights)
|
||||
iweights = tl.interleave(iweights, iweights)
|
||||
|
||||
# Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
|
||||
# that will map given indices to the correct order.
|
||||
reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] +
|
||||
tl.arange(0, 4)[:, None]).reshape(8)
|
||||
|
||||
# Use this to compute a set of shifts that can be used to unpack and
|
||||
# reorder the values in iweights and zeros.
|
||||
shifts = reverse_awq_order_tensor * 4
|
||||
shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8))
|
||||
shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
|
||||
|
||||
# Unpack and reorder: shift out the correct 4-bit value and mask.
|
||||
iweights = (iweights >> shifts) & 0xF
|
||||
|
||||
# Compute zero offsets and masks.
|
||||
zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
|
||||
zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
|
||||
zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :]
|
||||
|
||||
zero_masks_y = zero_offsets_y < num_rows // group_size
|
||||
zero_masks_x = zero_offsets_x < num_cols
|
||||
zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :]
|
||||
|
||||
# Load the zeros.
|
||||
zeros = tl.load(zeros_ptr + zero_offsets, zero_masks, 0.0)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
|
||||
|
||||
# Unpack and reorder: shift out the correct 4-bit value and mask.
|
||||
zeros = (zeros >> shifts) & 0xF
|
||||
|
||||
# Compute scale offsets and masks.
|
||||
scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
|
||||
scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 +
|
||||
tl.arange(0, BLOCK_SIZE_X * 8))
|
||||
scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] +
|
||||
scale_offsets_x[None, :])
|
||||
scale_masks_y = scale_offsets_y < num_rows // group_size
|
||||
scale_masks_x = scale_offsets_x < num_cols * 8
|
||||
scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :]
|
||||
|
||||
# Load the scales.
|
||||
scales = tl.load(scales_ptr + scale_offsets, scale_masks, 0.0)
|
||||
scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
|
||||
|
||||
# Dequantize.
|
||||
iweights = (iweights - zeros) * scales
|
||||
iweights = iweights.to(result_ptr.type.element_ty)
|
||||
|
||||
# Finally, store.
|
||||
tl.store(result_ptr + result_offsets, iweights, result_masks)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
|
||||
group_size, BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
pid_z = tl.program_id(1)
|
||||
|
||||
# NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
|
||||
# num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
|
||||
pid_m = pid // num_pid_n
|
||||
pid_n = pid % num_pid_n
|
||||
|
||||
accumulator_dtype = c_ptr.type.element_ty
|
||||
|
||||
# NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
|
||||
# accumulator = tl.arange(0, BLOCK_SIZE_N)
|
||||
# accumulator = tl.broadcast_to(accumulator[None, :],
|
||||
# (BLOCK_SIZE_M, BLOCK_SIZE_N))
|
||||
# accumulator = accumulator & 0x0
|
||||
# accumulator = accumulator.to(accumulator_dtype)
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),
|
||||
dtype=accumulator_dtype)
|
||||
|
||||
# Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
|
||||
# that will map given indices to the correct order.
|
||||
reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] +
|
||||
tl.arange(0, 4)[:, None]).reshape(8)
|
||||
|
||||
# Create the necessary shifts to use to unpack.
|
||||
shifts = reverse_awq_order_tensor * 4
|
||||
shifts = tl.broadcast_to(shifts[None, :],
|
||||
(BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8))
|
||||
shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N))
|
||||
|
||||
# Offsets and masks.
|
||||
offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
masks_am = offsets_am < M
|
||||
|
||||
offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
|
||||
masks_bn = offsets_bn < N // 8
|
||||
|
||||
offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
|
||||
masks_zn = offsets_zn < N // 8
|
||||
|
||||
offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
masks_sn = offsets_sn < N
|
||||
|
||||
offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
||||
offsets_a = K * offsets_am[:, None] + offsets_k[None, :]
|
||||
offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :]
|
||||
|
||||
a_ptrs = a_ptr + offsets_a
|
||||
b_ptrs = b_ptr + offsets_b
|
||||
|
||||
# NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv
|
||||
# block_offset = BLOCK_SIZE_K * SPLIT_K
|
||||
# for k in range(0, (K + block_offset - 1) // (block_offset)):
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
|
||||
masks_k = offsets_k < K
|
||||
masks_a = masks_am[:, None] & masks_k[None, :]
|
||||
a = tl.load(a_ptrs, mask=masks_a, other=0.0)
|
||||
|
||||
masks_b = masks_k[:, None] & masks_bn[None, :]
|
||||
b = tl.load(b_ptrs, mask=masks_b, other=0.0)
|
||||
b = tl.interleave(b, b)
|
||||
b = tl.interleave(b, b)
|
||||
b = tl.interleave(b, b)
|
||||
|
||||
# Dequantize b.
|
||||
offsets_szk = (
|
||||
(BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size +
|
||||
tl.arange(0, 1))
|
||||
offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]
|
||||
masks_zk = offsets_szk < K // group_size
|
||||
masks_z = masks_zk[:, None] & masks_zn[None, :]
|
||||
zeros_ptrs = zeros_ptr + offsets_z
|
||||
zeros = tl.load(zeros_ptrs, mask=masks_z, other=0.0)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N))
|
||||
|
||||
offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]
|
||||
masks_sk = offsets_szk < K // group_size
|
||||
masks_s = masks_sk[:, None] & masks_sn[None, :]
|
||||
scales_ptrs = scales_ptr + offsets_s
|
||||
scales = tl.load(scales_ptrs, mask=masks_s, other=0.0)
|
||||
scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))
|
||||
|
||||
b = (b >> shifts) & 0xF
|
||||
zeros = (zeros >> shifts) & 0xF
|
||||
b = (b - zeros) * scales
|
||||
b = b.to(c_ptr.type.element_ty)
|
||||
|
||||
# Accumulate results.
|
||||
accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
|
||||
|
||||
offsets_k += BLOCK_SIZE_K * SPLIT_K
|
||||
a_ptrs += BLOCK_SIZE_K * SPLIT_K
|
||||
b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8)
|
||||
|
||||
c = accumulator.to(c_ptr.type.element_ty)
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = c_ptr + pid_z * N * M + N * offs_cm[:, None] + offs_cn[None, :]
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
|
||||
# qweights - [K , M // 8], int32
|
||||
# scales - [K // G, M ], float16
|
||||
# zeros - [K // G, M // 8], int32
|
||||
def awq_dequantize_triton(qweight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
zeros: torch.Tensor,
|
||||
block_size_x: int = 32,
|
||||
block_size_y: int = 32) -> torch.Tensor:
|
||||
K = qweight.shape[0]
|
||||
M = scales.shape[1]
|
||||
group_size = qweight.shape[0] // scales.shape[0]
|
||||
|
||||
assert K > 0 and M > 0
|
||||
assert scales.shape[0] == K // group_size and scales.shape[1] == M
|
||||
assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8
|
||||
assert group_size <= K
|
||||
assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K
|
||||
|
||||
# Result tensor:
|
||||
# number of rows = same as input tensor
|
||||
# number of cols = 8 x input tensor num cols
|
||||
result = torch.empty(qweight.shape[0],
|
||||
qweight.shape[1] * 8,
|
||||
device=qweight.device,
|
||||
dtype=scales.dtype)
|
||||
|
||||
Y = qweight.shape[0] # num rows
|
||||
X = qweight.shape[1] # num cols
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(X, META['BLOCK_SIZE_X']),
|
||||
triton.cdiv(Y, META['BLOCK_SIZE_Y']),
|
||||
)
|
||||
awq_dequantize_kernel[grid](qweight,
|
||||
scales,
|
||||
zeros,
|
||||
group_size,
|
||||
result,
|
||||
X,
|
||||
Y,
|
||||
BLOCK_SIZE_X=block_size_x,
|
||||
BLOCK_SIZE_Y=block_size_y)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# input - [M, K]
|
||||
# qweight - [K, N // 8]
|
||||
# qzeros - [K // G, N // 8]
|
||||
# scales - [K // G, N]
|
||||
# split_k_iters - parallelism along K-dimension, int, power of 2.
|
||||
def awq_gemm_triton(input: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
qzeros: torch.Tensor,
|
||||
split_k_iters: int,
|
||||
block_size_m: int = 32,
|
||||
block_size_n: int = 32,
|
||||
block_size_k: int = 32) -> torch.Tensor:
|
||||
M, K = input.shape
|
||||
N = qweight.shape[1] * 8
|
||||
group_size = qweight.shape[0] // qzeros.shape[0]
|
||||
|
||||
assert N > 0 and K > 0 and M > 0
|
||||
assert qweight.shape[0] == K and qweight.shape[1] == N // 8
|
||||
assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8
|
||||
assert scales.shape[0] == K // group_size and scales.shape[1] == N
|
||||
assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0
|
||||
assert split_k_iters <= 32
|
||||
assert group_size <= K
|
||||
assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
|
||||
N, META['BLOCK_SIZE_N']),
|
||||
split_k_iters,
|
||||
)
|
||||
|
||||
result = torch.zeros((split_k_iters, M, N),
|
||||
dtype=scales.dtype,
|
||||
device=input.device)
|
||||
|
||||
# A = input, B = qweight, C = result
|
||||
# A = M x K, B = K x N, C = M x N
|
||||
awq_gemm_kernel[grid](input,
|
||||
qweight,
|
||||
result,
|
||||
qzeros,
|
||||
scales,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
group_size,
|
||||
BLOCK_SIZE_M=block_size_m,
|
||||
BLOCK_SIZE_N=block_size_n,
|
||||
BLOCK_SIZE_K=block_size_k,
|
||||
SPLIT_K=split_k_iters)
|
||||
|
||||
result = result.sum(0)
|
||||
|
||||
return result
|
||||
154
model_executor/layers/quantization/base_config.py
Normal file
154
model_executor/layers/quantization/base_config.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
else:
|
||||
QuantizationMethods = str
|
||||
|
||||
|
||||
class QuantizeMethodBase(ABC):
|
||||
"""Base class for different quantized methods."""
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(self, layer: torch.nn.Module, *weight_args,
|
||||
**extra_weight_attrs):
|
||||
"""Create weights for a layer.
|
||||
|
||||
The weights will be set as attributes of the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
|
||||
"""Apply the weights in layer to the input tensor.
|
||||
|
||||
Expects create_weights to have been called before on the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
# Not required functions
|
||||
def embedding(self, layer: torch.nn.Module, *args,
|
||||
**kwargs) -> torch.Tensor:
|
||||
"""Gather embeddings in the layer based on indices in the input tensor.
|
||||
|
||||
Expects create_weights to have been called before on the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
def process_weights_after_loading(self, layer: nn.Module) -> None:
|
||||
"""Process the weight after loading.
|
||||
|
||||
This can be used for example, to transpose weights for computation.
|
||||
"""
|
||||
return
|
||||
|
||||
|
||||
def method_has_implemented_embedding(
|
||||
method_class: type[QuantizeMethodBase]) -> bool:
|
||||
"""
|
||||
Not all quant methods have embedding implemented, so we need to check that
|
||||
it exists for our given method. We check this by making sure the function
|
||||
has been changed from the base implementation.
|
||||
"""
|
||||
base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding",
|
||||
None)
|
||||
class_embedding = inspect.getattr_static(method_class, "embedding", None)
|
||||
|
||||
return (class_embedding is not None
|
||||
and class_embedding is not base_embedding)
|
||||
|
||||
|
||||
class QuantizationConfig(ABC):
|
||||
"""Base class for quantization configs."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# mapping is updated by models as they initialize
|
||||
self.packed_modules_mapping: dict[str, list[str]] = dict()
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
"""Name of the quantization method."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
"""List of supported activation dtypes."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
"""Minimum GPU capability to support the quantization method.
|
||||
|
||||
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
|
||||
This requirement is due to the custom CUDA kernels used by the
|
||||
quantization method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
"""List of filenames to search for in the model directory."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "QuantizationConfig":
|
||||
"""Create a config class from the model's quantization config."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
"""
|
||||
Detects if this quantization method can support a given checkpoint
|
||||
format by overriding the user specified quantization method --
|
||||
this method should only be overwritten by subclasses in exceptional
|
||||
circumstances
|
||||
"""
|
||||
if(user_quant != None):
|
||||
return user_quant
|
||||
else:
|
||||
return hf_quant_cfg["quant_method"]
|
||||
|
||||
@staticmethod
|
||||
def get_from_keys(config: dict[str, Any], keys: list[str]) -> Any:
|
||||
"""Get a value from the model's quantization config."""
|
||||
for key in keys:
|
||||
if key in config:
|
||||
return config[key]
|
||||
raise ValueError(f"Cannot find any of {keys} in the model's "
|
||||
"quantization config.")
|
||||
|
||||
@staticmethod
|
||||
def get_from_keys_or(config: dict[str, Any], keys: list[str],
|
||||
default: Any) -> Any:
|
||||
"""Get a optional value from the model's quantization config."""
|
||||
try:
|
||||
return QuantizationConfig.get_from_keys(config, keys)
|
||||
except ValueError:
|
||||
return default
|
||||
|
||||
@abstractmethod
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional[QuantizeMethodBase]:
|
||||
"""Get the quantize method to use for the quantized layer.
|
||||
|
||||
Args:
|
||||
layer: The layer for the quant method.
|
||||
prefix: The full name of the layer in the state dict
|
||||
Returns:
|
||||
The quantize method. None if the given layer doesn't support quant
|
||||
method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_cache_scale(self, name: str) -> Optional[str]:
|
||||
return None
|
||||
461
model_executor/layers/quantization/bitblas.py
Normal file
461
model_executor/layers/quantization/bitblas.py
Normal file
@@ -0,0 +1,461 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
|
||||
BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_NUM_BITS,
|
||||
BITBLAS_SUPPORTED_SYM, MINIMUM_BITBLAS_VERSION)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BitBLASConfig(QuantizationConfig):
|
||||
"""Config class for BitBLAS.
|
||||
|
||||
Reference: https://github.com/Microsoft/BitBLAS
|
||||
"""
|
||||
TORCH_DTYPE = torch.float16
|
||||
STORAGE_DTYPE = "int8" # assume int8 storage
|
||||
TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE)
|
||||
# "original" or "rescale" or "quantized",
|
||||
# gptq_with_bitblas prefer "quantized implementation"
|
||||
ZEROS_MODE = "quantized"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: Optional[int],
|
||||
desc_act: Optional[bool],
|
||||
is_sym: Optional[bool],
|
||||
quant_method: Optional[str],
|
||||
lm_head_quantized: bool,
|
||||
) -> None:
|
||||
try:
|
||||
import bitblas
|
||||
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
|
||||
raise ImportError(
|
||||
"bitblas version is wrong. Please "
|
||||
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
|
||||
except ImportError as e:
|
||||
bitblas_import_exception = e
|
||||
raise ValueError(
|
||||
"Trying to use the bitblas backend, but could not import"
|
||||
f"with the following error: {bitblas_import_exception}. "
|
||||
"Please install bitblas through the following command: "
|
||||
f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
|
||||
) from bitblas_import_exception
|
||||
|
||||
if desc_act and group_size == -1:
|
||||
# In this case, act_order == True is the same as act_order == False
|
||||
# (since we have only one group per output channel)
|
||||
desc_act = False
|
||||
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.desc_act = desc_act
|
||||
self.is_sym = is_sym
|
||||
self.quant_method = quant_method
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
|
||||
# Verify
|
||||
if self.weight_bits not in BITBLAS_SUPPORTED_NUM_BITS:
|
||||
raise ValueError(
|
||||
f"BitBLAS does not support weight_bits = {self.weight_bits}. "
|
||||
f"Only weight_bits = {BITBLAS_SUPPORTED_NUM_BITS} "
|
||||
"are supported.")
|
||||
|
||||
if self.is_sym not in BITBLAS_SUPPORTED_SYM:
|
||||
raise ValueError(
|
||||
f"BitBLAS does not support is_sym = {self.is_sym}. "
|
||||
f"Only sym = {BITBLAS_SUPPORTED_SYM} are supported.")
|
||||
|
||||
storage_dtype = self.STORAGE_DTYPE
|
||||
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
|
||||
|
||||
self.storage_dtype = storage_dtype
|
||||
self.storage_torch_dtype = self.TORCH_STORAGE_DTYPE
|
||||
# 4 Bits packed into 32 bit datatype.
|
||||
self.pack_factor = storage_nbit // weight_bits
|
||||
self.nbits = weight_bits
|
||||
|
||||
# Zeros type for the quantized weights.
|
||||
self.zeros_mode = self.ZEROS_MODE
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"BitBLASConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act}, "
|
||||
f"is_sym={self.is_sym}, "
|
||||
f"quant_method={self.quant_method})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "bitblas"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@staticmethod
|
||||
def get_from_keys(config: dict[str, Any],
|
||||
keys: list[str],
|
||||
default: Any = None) -> Any:
|
||||
"""Get a value from the model's quantization config."""
|
||||
for key in keys:
|
||||
if key in config:
|
||||
return config[key]
|
||||
return default
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "BitBLASConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"], -1)
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"], False)
|
||||
is_sym = cls.get_from_keys(config, ["sym"], False)
|
||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
return cls(weight_bits, group_size, desc_act, is_sym, quant_method,
|
||||
lm_head_quantized)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
# compat: autogptq >=0.8.0 use checkpoint_format: str
|
||||
# compat: autogptq <=0.7.1 is_bitblas_format: bool
|
||||
is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas"
|
||||
or hf_quant_cfg.get("is_bitblas_format", False))
|
||||
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
|
||||
or user_quant == "bitblas")
|
||||
|
||||
if is_bitblas_format and is_valid_user_quant:
|
||||
msg = ("The model is serialized in {} format. Using {} kernel.".
|
||||
format(cls.get_name(), cls.get_name()))
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
return None
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["BitBLASLinearMethod"]:
|
||||
if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
|
||||
and self.lm_head_quantized):
|
||||
return BitBLASLinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class BitBLASLinearMethod(LinearMethodBase):
|
||||
"""Linear method for BitBLAS.
|
||||
|
||||
Args:
|
||||
quant_config: The BitBLAS quantization config.
|
||||
"""
|
||||
# USE BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS
|
||||
# Instead of BITBLAS_OPTIMIZE_FEATURES
|
||||
# If you want to high contiguous batching
|
||||
# performance
|
||||
OPT_FEATURES = BITBLAS_OPTIMIZE_FEATURES
|
||||
ENABLE_TUNING = True
|
||||
BITBLAS_DTYPES = {
|
||||
torch.float32: "float32",
|
||||
torch.float16: "float16",
|
||||
torch.bfloat16: "bfloat16",
|
||||
torch.half: "float16",
|
||||
torch.int8: "int8",
|
||||
}
|
||||
|
||||
def __init__(self, quant_config: BitBLASConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights_gptq(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
"""Creates quantized weights for use in linear operations.
|
||||
|
||||
The function initializes and returns a dictionary containing quantized
|
||||
weights, scales, and zeros
|
||||
for performing quantized matrix multiplication operations.
|
||||
|
||||
Args:
|
||||
input_size_per_partition: The size of the input partition.
|
||||
output_size_per_partition: The size of the output partition.
|
||||
input_size: The total size of the input (unused).
|
||||
output_size: The total size of the output (unused).
|
||||
params_dtype:
|
||||
The data type of the parameters (expected to be torch.float16).
|
||||
|
||||
Returns:
|
||||
A dictionary containing the quantized weights ('qweight'),
|
||||
scales ('scales'), and zeros ('zeros').
|
||||
|
||||
Raises:
|
||||
ValueError: If `params_dtype` is not `torch.float16` or if the
|
||||
input size per partition is not divisible by the group size in
|
||||
`quant_config`.
|
||||
"""
|
||||
del input_size, output_size # Unused arguments.
|
||||
weight_loader = extra_weight_attrs["weight_loader"]
|
||||
|
||||
if params_dtype not in self.quant_config.get_supported_act_dtypes():
|
||||
raise ValueError("Parameter data type must be torch.float16, "
|
||||
f"but got {params_dtype}")
|
||||
group_size = self.quant_config.group_size
|
||||
if group_size is None:
|
||||
group_size = -1
|
||||
# Validate output_size_per_partition
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if (group_size != -1 and input_size_per_partition % group_size != 0):
|
||||
raise ValueError(
|
||||
f"Input size per partition ({input_size_per_partition}) must "
|
||||
f"be divisible by group size ({group_size}).")
|
||||
|
||||
# Initialize or retrieve the BitBLAS matrix multiplication operator.
|
||||
self._configure_bitblas_matmul(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition,
|
||||
params_dtype=params_dtype,
|
||||
enable_tuning=self.ENABLE_TUNING,
|
||||
bias=False,
|
||||
layout="nt",
|
||||
bits=self.quant_config.weight_bits,
|
||||
)
|
||||
|
||||
# Initialize quantized weights with dimensions
|
||||
# Quantized 4Bit weights packed.
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
self.bitblas_matmul.retrieve_weight_shape(),
|
||||
device="cuda",
|
||||
dtype=self.quant_config.storage_torch_dtype,
|
||||
requires_grad=False,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
bitblas_tile_size=(self.bitblas_matmul.retrieve_weight_shape()[-2]
|
||||
if self.bitblas_matmul.propagate_b else None),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
# Compute the number of input groups for channel-wise quantization.
|
||||
input_groups = (1 if group_size == -1 else input_size_per_partition //
|
||||
group_size)
|
||||
|
||||
# Initialize scales and zeros for the quantized weights.
|
||||
weight_scale_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
output_size_per_partition,
|
||||
input_groups,
|
||||
device="cuda",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
if input_groups == 1:
|
||||
scales = ChannelQuantScaleParameter(output_dim=0,
|
||||
**weight_scale_args)
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(output_dim=0,
|
||||
input_dim=1,
|
||||
**weight_scale_args)
|
||||
|
||||
if self.quant_config.zeros_mode == "quantized":
|
||||
zeros = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_groups,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=self.quant_config.storage_torch_dtype,
|
||||
requires_grad=False,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
else:
|
||||
zeros = BasevLLMParameter(
|
||||
torch.empty(output_size_per_partition,
|
||||
input_groups,
|
||||
device="cuda",
|
||||
dtype=params_dtype),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
# Set attributes to indicate how scales and zeros are applied.
|
||||
set_weight_attrs(zeros, {
|
||||
"input_dim": None if input_groups == 1 else 1,
|
||||
"output_dim": 0,
|
||||
})
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("scales", scales)
|
||||
layer.register_parameter("zeros", zeros)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
if self.quant_config.quant_method == "gptq":
|
||||
return self.create_weights_gptq(layer, input_size_per_partition,
|
||||
output_partition_sizes, input_size,
|
||||
output_size, params_dtype,
|
||||
**extra_weight_attrs)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported quant_method {self.quant_config.quant_method}")
|
||||
|
||||
def _configure_bitblas_matmul(
|
||||
self,
|
||||
infeatures,
|
||||
outfeatures,
|
||||
params_dtype,
|
||||
enable_tuning,
|
||||
bias,
|
||||
layout,
|
||||
bits,
|
||||
out_dtype="float16",
|
||||
):
|
||||
from bitblas import MatmulConfig
|
||||
bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]
|
||||
|
||||
with_scaling = False
|
||||
with_zeros = False
|
||||
group_size = self.quant_config.group_size
|
||||
zeros_mode = self.quant_config.zeros_mode
|
||||
if self.quant_config.quant_method == "gptq":
|
||||
with_scaling = True
|
||||
with_zeros = True
|
||||
W_dtype = f"uint{bits}"
|
||||
if self.quant_config.is_sym:
|
||||
with_zeros = False
|
||||
W_dtype = f"int{bits}"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported quant_method {self.quant_config.quant_method}")
|
||||
|
||||
matmul_config = MatmulConfig(
|
||||
N=outfeatures,
|
||||
K=infeatures,
|
||||
A_dtype=bitblas_dtype,
|
||||
W_dtype=W_dtype,
|
||||
out_dtype=out_dtype,
|
||||
accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype,
|
||||
storage_dtype=self.quant_config.STORAGE_DTYPE,
|
||||
with_scaling=with_scaling,
|
||||
with_zeros=with_zeros,
|
||||
group_size=group_size,
|
||||
with_bias=bias,
|
||||
layout=layout,
|
||||
zeros_mode=zeros_mode,
|
||||
)
|
||||
self.bitblas_matmul = self._get_or_create_bitblas_operator(
|
||||
matmul_config, enable_tuning)
|
||||
|
||||
def _get_or_create_bitblas_operator(self, config, enable_tuning):
|
||||
from bitblas import Matmul, auto_detect_nvidia_target
|
||||
from bitblas.cache import get_database_path, global_operator_cache
|
||||
BITBLAS_DATABASE_PATH = get_database_path()
|
||||
BITBLAS_TARGET = auto_detect_nvidia_target()
|
||||
if global_operator_cache.size() == 0:
|
||||
global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
|
||||
BITBLAS_TARGET)
|
||||
|
||||
bitblas_matmul = global_operator_cache.get(config)
|
||||
if bitblas_matmul is None:
|
||||
bitblas_matmul = Matmul(config,
|
||||
target=BITBLAS_TARGET,
|
||||
enable_tuning=False)
|
||||
if enable_tuning:
|
||||
TUNING_MESSAGE = (f"BitBLAS Operator {config} is tuning ...")
|
||||
logger.info(TUNING_MESSAGE)
|
||||
bitblas_matmul.hardware_aware_finetune(topk=20)
|
||||
global_operator_cache.add(config, bitblas_matmul)
|
||||
global_operator_cache.save_into_database(
|
||||
BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
|
||||
TUNED_MESSAGE = (
|
||||
f"BitBLAS Operator {config} tuned and saved to database.")
|
||||
logger.info(TUNED_MESSAGE)
|
||||
else:
|
||||
_message = f"BitBLAS Operator {config} created."
|
||||
logger.info(_message)
|
||||
else:
|
||||
_message = (
|
||||
f"BitBLAS Operator {config} found in global_operator_cache.")
|
||||
logger.info(_message)
|
||||
return bitblas_matmul
|
||||
|
||||
def apply_gptq(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
qweight = layer.qweight
|
||||
scales = layer.scales
|
||||
qzeros = layer.zeros
|
||||
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
if self.quant_config.is_sym:
|
||||
output_2d = self.bitblas_matmul(x_2d, qweight, scales)
|
||||
else:
|
||||
output_2d = self.bitblas_matmul(x_2d, qweight, scales, qzeros)
|
||||
|
||||
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output
|
||||
|
||||
def apply(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> torch.Tensor:
|
||||
if self.quant_config.quant_method == "gptq":
|
||||
return self.apply_gptq(*args, **kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported quant_method {self.quant_config.quant_method}")
|
||||
396
model_executor/layers/quantization/bitsandbytes.py
Normal file
396
model_executor/layers/quantization/bitsandbytes.py
Normal file
@@ -0,0 +1,396 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
|
||||
class BitsAndBytesConfig(QuantizationConfig):
|
||||
"""Config class for BitsAndBytes Quantization.
|
||||
|
||||
Reference: https://arxiv.org/abs/2305.14314
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
load_in_8bit: bool = False,
|
||||
load_in_4bit: bool = True,
|
||||
bnb_4bit_compute_dtype: str = "float32",
|
||||
bnb_4bit_quant_storage: str = "uint8",
|
||||
bnb_4bit_quant_type: str = "fp4",
|
||||
bnb_4bit_use_double_quant: bool = False,
|
||||
llm_int8_enable_fp32_cpu_offload: bool = False,
|
||||
llm_int8_has_fp16_weight: bool = False,
|
||||
llm_int8_skip_modules: Optional[list[str]] = None,
|
||||
llm_int8_threshold: float = 6.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.load_in_8bit = load_in_8bit
|
||||
self.load_in_4bit = load_in_4bit
|
||||
self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
|
||||
self.bnb_4bit_quant_storage = bnb_4bit_quant_storage
|
||||
self.bnb_4bit_quant_type = bnb_4bit_quant_type
|
||||
self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
|
||||
self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
|
||||
self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
|
||||
self.llm_int8_skip_modules = llm_int8_skip_modules or []
|
||||
self.llm_int8_threshold = llm_int8_threshold
|
||||
|
||||
if self.bnb_4bit_quant_storage not in ["uint8"]:
|
||||
raise ValueError("Unsupported bnb_4bit_quant_storage: "
|
||||
f"{self.bnb_4bit_quant_storage}")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, "
|
||||
f"load_in_4bit={self.load_in_4bit}, "
|
||||
f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
|
||||
f"bnb_4bit_quant_storage={self.bnb_4bit_quant_storage}, "
|
||||
f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
|
||||
f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
|
||||
|
||||
@classmethod
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "bitsandbytes"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
return [torch.float32, torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "BitsAndBytesConfig":
|
||||
|
||||
def get_safe_value(config, keys, default_value=None):
|
||||
try:
|
||||
value = cls.get_from_keys(config, keys)
|
||||
return value if value is not None else default_value
|
||||
except ValueError:
|
||||
return default_value
|
||||
|
||||
load_in_8bit = get_safe_value(config, ["load_in_8bit"],
|
||||
default_value=False)
|
||||
load_in_4bit = get_safe_value(config, ["load_in_4bit"],
|
||||
default_value=True)
|
||||
bnb_4bit_compute_dtype = get_safe_value(config,
|
||||
["bnb_4bit_compute_dtype"],
|
||||
default_value="float32")
|
||||
bnb_4bit_quant_storage = get_safe_value(config,
|
||||
["bnb_4bit_quant_storage"],
|
||||
default_value="uint8")
|
||||
bnb_4bit_quant_type = get_safe_value(config, ["bnb_4bit_quant_type"],
|
||||
default_value="fp4")
|
||||
bnb_4bit_use_double_quant = get_safe_value(
|
||||
config, ["bnb_4bit_use_double_quant"], default_value=False)
|
||||
llm_int8_enable_fp32_cpu_offload = get_safe_value(
|
||||
config, ["llm_int8_enable_fp32_cpu_offload"], default_value=False)
|
||||
llm_int8_has_fp16_weight = get_safe_value(config,
|
||||
["llm_int8_has_fp16_weight"],
|
||||
default_value=False)
|
||||
llm_int8_skip_modules = get_safe_value(config,
|
||||
["llm_int8_skip_modules"],
|
||||
default_value=[])
|
||||
llm_int8_threshold = get_safe_value(config, ["llm_int8_threshold"],
|
||||
default_value=6.0)
|
||||
|
||||
return cls(
|
||||
load_in_8bit=load_in_8bit,
|
||||
load_in_4bit=load_in_4bit,
|
||||
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
|
||||
bnb_4bit_quant_storage=bnb_4bit_quant_storage,
|
||||
bnb_4bit_quant_type=bnb_4bit_quant_type,
|
||||
bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
|
||||
llm_int8_enable_fp32_cpu_offload=llm_int8_enable_fp32_cpu_offload,
|
||||
llm_int8_has_fp16_weight=llm_int8_has_fp16_weight,
|
||||
llm_int8_skip_modules=llm_int8_skip_modules,
|
||||
llm_int8_threshold=llm_int8_threshold)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["LinearMethodBase"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
|
||||
return UnquantizedLinearMethod()
|
||||
return BitsAndBytesLinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
|
||||
# Split the prefix into its dot-separated components
|
||||
components = prefix.split('.')
|
||||
|
||||
# Check if any of the skip modules exactly matches any component
|
||||
substr_check = any(module_name in components
|
||||
for module_name in llm_int8_skip_modules)
|
||||
|
||||
# Allow certain layers to not be quantized
|
||||
set_components = set(".".join(components[:i + 1])
|
||||
for i in range(len(components)))
|
||||
set_llm_int8_skip_modules = set(llm_int8_skip_modules)
|
||||
prefix_check = len(set_llm_int8_skip_modules & set_components) != 0
|
||||
|
||||
return substr_check or prefix_check
|
||||
|
||||
|
||||
class BitsAndBytesLinearMethod(LinearMethodBase):
|
||||
"""Linear method for BitsAndBytes.
|
||||
|
||||
Args:
|
||||
quant_config: The BitsAndBytes quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: BitsAndBytesConfig):
|
||||
try:
|
||||
import bitsandbytes
|
||||
if bitsandbytes.__version__ < "0.45.3":
|
||||
raise ImportError("bitsandbytes version is wrong. Please "
|
||||
"install bitsandbytes>=0.45.3.")
|
||||
except ImportError as err:
|
||||
raise ImportError("Please install bitsandbytes>=0.45.3 via "
|
||||
"`pip install bitsandbytes>=0.45.3` to use "
|
||||
"bitsandbytes quantizer.") from err
|
||||
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
from bitsandbytes.nn import Int8Params
|
||||
|
||||
def calculate_quant_ratio(dtype):
|
||||
if dtype.is_floating_point:
|
||||
return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits
|
||||
else:
|
||||
return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits
|
||||
|
||||
def create_qweight_for_8bit():
|
||||
qweight = Int8Params(
|
||||
data=torch.empty(sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=torch.int8),
|
||||
has_fp16_weights=self.quant_config.llm_int8_has_fp16_weight,
|
||||
requires_grad=False)
|
||||
set_weight_attrs(
|
||||
qweight, {
|
||||
"input_dim": 0,
|
||||
"output_dim": 0,
|
||||
"pack_factor": 1,
|
||||
"use_bitsandbytes_8bit": True,
|
||||
"generation": 0
|
||||
})
|
||||
return qweight
|
||||
|
||||
def create_qweight_for_4bit():
|
||||
quant_ratio = calculate_quant_ratio(params_dtype)
|
||||
|
||||
total_size = input_size_per_partition * sum(output_partition_sizes)
|
||||
if total_size % quant_ratio != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape.")
|
||||
|
||||
qweight = torch.nn.Parameter(torch.empty(total_size // quant_ratio,
|
||||
1,
|
||||
dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
set_weight_attrs(
|
||||
qweight, {
|
||||
"input_dim": 0,
|
||||
"output_dim": 0,
|
||||
"pack_factor": quant_ratio,
|
||||
"use_bitsandbytes_4bit": True
|
||||
})
|
||||
return qweight
|
||||
|
||||
if self.quant_config.load_in_8bit:
|
||||
qweight = create_qweight_for_8bit()
|
||||
else:
|
||||
qweight = create_qweight_for_4bit()
|
||||
# Enable parameters to have the same name as in the BNB
|
||||
# checkpoint format.
|
||||
layer.register_parameter("weight", qweight)
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
if self.quant_config.load_in_8bit:
|
||||
return self._apply_8bit_weight(layer, x, bias)
|
||||
else:
|
||||
return self._apply_4bit_weight(layer, x, bias)
|
||||
|
||||
def _apply_8bit_weight(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
# only load the bitsandbytes module when needed
|
||||
from bitsandbytes import MatmulLtState, matmul
|
||||
|
||||
original_type = x.dtype
|
||||
original_shape = x.shape
|
||||
reshape_after_matmul = False
|
||||
if x.ndim > 2:
|
||||
x = x.reshape(-1, x.size(-1))
|
||||
reshape_after_matmul = True
|
||||
bf_x = x.to(torch.bfloat16)
|
||||
|
||||
qweight = layer.weight
|
||||
offsets = qweight.bnb_shard_offsets
|
||||
quant_states = qweight.bnb_quant_state
|
||||
matmul_states = qweight.matmul_state
|
||||
generation = qweight.generation
|
||||
|
||||
out_dim_0 = x.shape[0]
|
||||
out_dim_1 = sum(
|
||||
[quant_state[1].shape[0] for quant_state in quant_states.items()])
|
||||
out = torch.empty(out_dim_0,
|
||||
out_dim_1,
|
||||
dtype=torch.float16,
|
||||
device=x.device)
|
||||
|
||||
current_index = 0
|
||||
for i in range(len(quant_states)):
|
||||
output_size = quant_states[i].shape[0]
|
||||
|
||||
# in profile_run or the first generation of inference,
|
||||
# create new matmul_states
|
||||
if generation == 0 or generation == 1:
|
||||
matmul_states[i] = MatmulLtState()
|
||||
matmul_states[i].CB = qweight[offsets[i]:offsets[i + 1]]
|
||||
matmul_states[i].SCB = quant_states[i].to(x.device)
|
||||
matmul_states[i].threshold = (
|
||||
self.quant_config.llm_int8_threshold)
|
||||
matmul_states[i].has_fp16_weights = (
|
||||
self.quant_config.llm_int8_has_fp16_weight)
|
||||
matmul_states[i].is_training = False
|
||||
if matmul_states[i].threshold > 0.0 and not matmul_states[
|
||||
i].has_fp16_weights:
|
||||
matmul_states[i].use_pool = True
|
||||
|
||||
new_x = bf_x.unsqueeze(0)
|
||||
|
||||
out[:, current_index:current_index + output_size] = matmul(
|
||||
new_x,
|
||||
qweight[offsets[i]:offsets[i + 1]],
|
||||
state=matmul_states[i])
|
||||
|
||||
current_index += output_size
|
||||
|
||||
# only update the matmul_states if it is not profile_run
|
||||
if (generation > 0
|
||||
and not self.quant_config.llm_int8_has_fp16_weight
|
||||
and matmul_states[i].CB is not None
|
||||
and matmul_states[i].CxB is not None):
|
||||
del matmul_states[i].CB
|
||||
qweight[offsets[i]:offsets[i + 1]] = matmul_states[i].CxB
|
||||
|
||||
out = out.to(original_type)
|
||||
|
||||
if reshape_after_matmul:
|
||||
out = out.view(*original_shape[:-1], out.size(-1))
|
||||
|
||||
if bias is not None:
|
||||
out += bias
|
||||
|
||||
qweight.generation += 1
|
||||
|
||||
return out
|
||||
|
||||
def _apply_4bit_weight(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
original_type = x.dtype
|
||||
original_shape = x.shape
|
||||
reshape_after_matmul = False
|
||||
if x.ndim > 2:
|
||||
x = x.reshape(-1, x.size(-1))
|
||||
reshape_after_matmul = True
|
||||
bf_x = x.to(torch.bfloat16)
|
||||
|
||||
qweight = layer.weight
|
||||
quant_states = qweight.bnb_quant_state
|
||||
offsets = qweight.bnb_shard_offsets
|
||||
|
||||
out_dim_0 = x.shape[0]
|
||||
out_dim_1 = sum(
|
||||
[quant_state[1].shape[0] for quant_state in quant_states.items()])
|
||||
out = torch.empty(out_dim_0,
|
||||
out_dim_1,
|
||||
dtype=torch.bfloat16,
|
||||
device=x.device)
|
||||
apply_bnb_4bit(bf_x, qweight, offsets, out)
|
||||
out = out.to(original_type)
|
||||
|
||||
if reshape_after_matmul:
|
||||
out = out.view(*original_shape[:-1], out.size(-1))
|
||||
|
||||
if bias is not None:
|
||||
out += bias
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _apply_bnb_4bit(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
offsets: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
) -> None:
|
||||
# only load the bitsandbytes module when needed
|
||||
from bitsandbytes import matmul_4bit
|
||||
quant_states = weight.bnb_quant_state
|
||||
current_index = 0
|
||||
for i in range(len(quant_states)):
|
||||
output_size = quant_states[i].shape[0]
|
||||
# It is more efficient to use out kwarg like
|
||||
# matmul_4bit(..., out = ...). Infeasible now due to the bug
|
||||
# https://github.com/TimDettmers/bitsandbytes/issues/1235.
|
||||
# Need to change after the bug is fixed.
|
||||
out[:, current_index:current_index + output_size] = matmul_4bit(
|
||||
x, weight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
|
||||
current_index += output_size
|
||||
|
||||
|
||||
def _apply_bnb_4bit_fake(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
offsets: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="apply_bnb_4bit",
|
||||
op_func=_apply_bnb_4bit,
|
||||
mutates_args=["out"],
|
||||
fake_impl=_apply_bnb_4bit_fake,
|
||||
)
|
||||
apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
@@ -0,0 +1,670 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from contextlib import suppress
|
||||
from typing import Any, Literal, Optional, cast
|
||||
|
||||
import torch
|
||||
from compressed_tensors.config import (CompressionFormat,
|
||||
SparsityCompressionConfig,
|
||||
SparsityStructure)
|
||||
from compressed_tensors.quantization import (QuantizationArgs,
|
||||
QuantizationStrategy,
|
||||
QuantizationType)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
|
||||
CompressedTensorsMoEMethod)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
|
||||
CompressedTensorsScheme, CompressedTensorsW4A4Fp4,
|
||||
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
|
||||
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
|
||||
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
find_matched_target, is_activation_quantization_format,
|
||||
should_ignore_layer)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["CompressedTensorsLinearMethod"]
|
||||
|
||||
SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
|
||||
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]]
|
||||
|
||||
|
||||
class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_scheme_map: dict[str, Any],
|
||||
ignore: list[str],
|
||||
quant_format: str,
|
||||
sparsity_scheme_map: dict[str, SparsityCompressionConfig],
|
||||
sparsity_ignore_list: list[str],
|
||||
kv_cache_scheme: Optional[dict[str, Any]] = None,
|
||||
config: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.ignore = ignore
|
||||
self.quant_format = quant_format
|
||||
# Map from [target -> scheme]
|
||||
self.target_scheme_map = target_scheme_map
|
||||
self.kv_cache_scheme = kv_cache_scheme
|
||||
self.sparsity_scheme_map = sparsity_scheme_map
|
||||
self.sparsity_ignore_list = sparsity_ignore_list
|
||||
self.config = config
|
||||
|
||||
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
|
||||
return CompressedTensorsLinearMethod(self)
|
||||
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "compressed-tensors"
|
||||
|
||||
def get_quant_method(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
|
||||
# Check if the layer is skipped for quantization.
|
||||
# TODO (@robertgshaw2): support module names
|
||||
if should_ignore_layer(prefix,
|
||||
ignore=self.ignore,
|
||||
fused_mapping=self.packed_modules_mapping):
|
||||
return UnquantizedLinearMethod()
|
||||
if isinstance(layer, LinearBase):
|
||||
scheme = self.get_scheme(layer=layer, layer_name=prefix)
|
||||
if scheme is None:
|
||||
return UnquantizedLinearMethod()
|
||||
layer.scheme = scheme
|
||||
return CompressedTensorsLinearMethod(self)
|
||||
if isinstance(layer, Attention):
|
||||
return CompressedTensorsKVCacheMethod(self)
|
||||
if isinstance(layer, FusedMoE):
|
||||
return CompressedTensorsMoEMethod.get_moe_method(self, layer)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
|
||||
ignore: list[str] = cast(list[str], config.get("ignore", []))
|
||||
quant_format = cast(str, config.get("format"))
|
||||
target_scheme_map = cls._quantization_scheme_map_from_config(
|
||||
config=config)
|
||||
sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
|
||||
config=config)
|
||||
|
||||
return cls(
|
||||
target_scheme_map=target_scheme_map,
|
||||
ignore=ignore,
|
||||
quant_format=quant_format,
|
||||
sparsity_scheme_map=sparsity_scheme_map,
|
||||
sparsity_ignore_list=sparsity_ignore_list,
|
||||
config=config,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _parse_sparsity_config(
|
||||
cls, config: dict[str, Any]
|
||||
) -> tuple[dict[str, SparsityCompressionConfig], list[str]]:
|
||||
"""
|
||||
:param config: The `quantization_config` dictionary from config.json
|
||||
:return: A tuple with two elements
|
||||
1. A dictionary mapping target layer names to their corresponding
|
||||
sparsity_config
|
||||
2. A list of layer names to ignore for sparsity
|
||||
"""
|
||||
if not (sparsity_config := config.get(SPARSITY_CONFIG_NAME)):
|
||||
return dict(), []
|
||||
|
||||
sparsity_config = SparsityCompressionConfig.model_validate(
|
||||
sparsity_config)
|
||||
sparse_scheme_map: dict[str, SparsityCompressionConfig] = {
|
||||
target: sparsity_config
|
||||
for target in sparsity_config.targets or list()
|
||||
}
|
||||
sparsity_ignore_list = sparsity_config.ignore or list()
|
||||
return sparse_scheme_map, sparsity_ignore_list
|
||||
|
||||
@classmethod
|
||||
def _quantization_scheme_map_from_config(
|
||||
cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
|
||||
"""
|
||||
:param config: The `quantization_config` dictionary from config.json
|
||||
:return: A dictionary mapping target layer names to their corresponding
|
||||
quantization_args for weights and input activations
|
||||
"""
|
||||
target_scheme_map: dict[str, Any] = dict()
|
||||
quant_format = cast(str, config.get("format"))
|
||||
|
||||
# The quant_config has multiple config_groups, each containing
|
||||
# an input_activations key with details about how the activations are
|
||||
# quantized, a weights key indicating how the weights are quantized,
|
||||
# and a list of targets under the `targets` key, dictating which
|
||||
# layers are impacted by the quantization details. The quantization
|
||||
# details follow the structure defined by the QuantizationArgs
|
||||
# pydantic model, which is used to verify the structure of the
|
||||
# quant_config and also store the details for later use.
|
||||
|
||||
config_groups = config.get("config_groups", dict())
|
||||
for _, quant_config in config_groups.items():
|
||||
targets = quant_config.get("targets")
|
||||
for target in targets:
|
||||
target_scheme_map[target] = {}
|
||||
target_scheme_map[target][
|
||||
"weights"] = QuantizationArgs.model_validate(
|
||||
quant_config.get("weights"))
|
||||
|
||||
target_scheme_map[target]["input_activations"] = None
|
||||
if is_activation_quantization_format(quant_format):
|
||||
input_activations = quant_config.get("input_activations")
|
||||
# The only case where we have activation quant supported
|
||||
# but no input_activations provided in the config
|
||||
# should be w8a16fp8 w8a16fp8 can also run for cases where
|
||||
# there is an input_quant but it is ignored
|
||||
if not input_activations:
|
||||
assert target_scheme_map[target][
|
||||
"weights"].type == QuantizationType.FLOAT
|
||||
else:
|
||||
target_scheme_map[target][
|
||||
"input_activations"] = QuantizationArgs.model_validate( # noqa: E501
|
||||
quant_config.get("input_activations"))
|
||||
return target_scheme_map
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
def _check_scheme_supported(self,
|
||||
min_capability: int,
|
||||
error: bool = True,
|
||||
match_exact: bool = False) -> bool:
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
"""
|
||||
if capability_tuple is not None:
|
||||
capability = capability_tuple.to_int()
|
||||
if match_exact:
|
||||
supported = capability == min_capability
|
||||
if error and not supported:
|
||||
raise RuntimeError(
|
||||
"Quantization scheme is not supported for ",
|
||||
"the current GPU. Required capability: ",
|
||||
f"{min_capability}. Current capability: {capability}.")
|
||||
else:
|
||||
supported = capability >= min_capability
|
||||
if error and not supported:
|
||||
raise RuntimeError(
|
||||
"Quantization scheme is not supported for ",
|
||||
f"the current GPU. Min capability: {min_capability}. ",
|
||||
f"Current capability: {capability}.")
|
||||
return supported
|
||||
else:
|
||||
return False
|
||||
"""
|
||||
return False
|
||||
|
||||
def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel):
|
||||
|
||||
if weight_quant is None or input_quant is None:
|
||||
return False
|
||||
|
||||
is_tensor_group_quant = (weight_quant.strategy
|
||||
== QuantizationStrategy.TENSOR_GROUP.value
|
||||
and input_quant.strategy
|
||||
== QuantizationStrategy.TENSOR_GROUP.value)
|
||||
is_symmetric = weight_quant.symmetric and input_quant.symmetric
|
||||
|
||||
is_group_size_16 = (weight_quant.group_size == 16
|
||||
and input_quant.group_size == 16)
|
||||
is_float_type = (weight_quant.type == QuantizationType.FLOAT
|
||||
and input_quant.type == QuantizationType.FLOAT.value)
|
||||
is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4
|
||||
|
||||
return (is_tensor_group_quant and is_float_type and is_4_bits
|
||||
and is_group_size_16 and is_symmetric)
|
||||
|
||||
def _is_fp4a16_nvfp4(self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel):
|
||||
|
||||
is_weight_only = weight_quant is not None and input_quant is None
|
||||
is_tensor_group_quant = (
|
||||
weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value)
|
||||
is_symmetric = weight_quant.symmetric
|
||||
|
||||
is_group_size_16 = weight_quant.group_size == 16
|
||||
is_float_type = weight_quant.type == QuantizationType.FLOAT
|
||||
is_4_bits = weight_quant.num_bits == 4
|
||||
|
||||
return (is_weight_only and is_tensor_group_quant and is_float_type
|
||||
and is_4_bits and is_group_size_16 and is_symmetric)
|
||||
|
||||
def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel) -> bool:
|
||||
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
|
||||
weight_strategy = (
|
||||
weight_quant.strategy == QuantizationStrategy.TENSOR.value
|
||||
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
|
||||
is_tensor = (weight_strategy and input_quant.strategy
|
||||
== QuantizationStrategy.TENSOR.value)
|
||||
is_static = not weight_quant.dynamic and not input_quant.dynamic
|
||||
|
||||
# Both symmetric and asymmetric input quantization supported.
|
||||
# Only symmetric weight quantization supported.
|
||||
return is_8_bits and is_tensor and weight_quant.symmetric and is_static
|
||||
|
||||
def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel) -> bool:
|
||||
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
|
||||
weight_strategy = (
|
||||
weight_quant.strategy == QuantizationStrategy.TENSOR.value
|
||||
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
|
||||
is_token = (weight_strategy and input_quant.strategy
|
||||
== QuantizationStrategy.TOKEN.value)
|
||||
is_dynamic = not weight_quant.dynamic and input_quant.dynamic
|
||||
|
||||
# Both symmetric and asymmetric input quantization supported.
|
||||
# Only symmetric weight quantization supported.
|
||||
return is_8_bits and is_token and weight_quant.symmetric and is_dynamic
|
||||
|
||||
def _is_fp8_w8a8(self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel) -> bool:
|
||||
# Confirm weights and activations quantized.
|
||||
if weight_quant is None or input_quant is None:
|
||||
return False
|
||||
|
||||
# Confirm weight scheme is supported.
|
||||
is_floating_point = (weight_quant.type == QuantizationType.FLOAT
|
||||
and input_quant.type == QuantizationType.FLOAT)
|
||||
is_symmetric_weight = weight_quant.symmetric
|
||||
is_static_weight = not weight_quant.dynamic
|
||||
is_per_tensor_or_channel_weight = (weight_quant.strategy in [
|
||||
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
|
||||
])
|
||||
if not (is_floating_point and is_symmetric_weight and is_static_weight
|
||||
and is_per_tensor_or_channel_weight):
|
||||
return False
|
||||
|
||||
# Dynamic quantization is always supported if weights supported.
|
||||
if input_quant.dynamic:
|
||||
return True
|
||||
|
||||
# Confirm activation scheme is supported.
|
||||
is_symmetric_activation = input_quant.symmetric
|
||||
is_per_tensor_activation = (
|
||||
input_quant.strategy == QuantizationStrategy.TENSOR)
|
||||
return is_symmetric_activation and is_per_tensor_activation
|
||||
|
||||
def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel) -> bool:
|
||||
return (self._check_scheme_supported(90, error=False, match_exact=True)
|
||||
and self._is_fp8_w8a8(weight_quant, input_quant))
|
||||
|
||||
def _is_fp8_w8a16(self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel) -> bool:
|
||||
# Confirm weights quantized.
|
||||
if weight_quant is None:
|
||||
return False
|
||||
|
||||
# Confirm we have floating points.
|
||||
if weight_quant.type != QuantizationType.FLOAT:
|
||||
return False
|
||||
|
||||
# Confirm weight scheme is supported.
|
||||
is_symmetric_weight = weight_quant.symmetric
|
||||
is_static_weight = not weight_quant.dynamic
|
||||
is_per_tensor_or_channel_weight = (weight_quant.strategy in [
|
||||
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
|
||||
])
|
||||
if not (is_symmetric_weight and is_static_weight # noqa: SIM103
|
||||
and is_per_tensor_or_channel_weight):
|
||||
return False
|
||||
|
||||
# All conditions satisfied.
|
||||
return True
|
||||
|
||||
def _is_wNa16_group_channel(self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel) -> bool:
|
||||
input_quant_none = input_quant is None
|
||||
is_channel_group = (
|
||||
weight_quant.strategy == QuantizationStrategy.CHANNEL.value
|
||||
or weight_quant.strategy == QuantizationStrategy.GROUP.value)
|
||||
is_static = not weight_quant.dynamic
|
||||
|
||||
return (is_channel_group and input_quant_none and is_static)
|
||||
|
||||
def _get_scheme_from_parts(
|
||||
self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel) -> "CompressedTensorsScheme":
|
||||
|
||||
# Detect If Mixed Precision
|
||||
if self._is_fp4a16_nvfp4(weight_quant, input_quant):
|
||||
return CompressedTensorsW4A16Fp4()
|
||||
|
||||
if self._is_wNa16_group_channel(weight_quant, input_quant):
|
||||
if (self.quant_format == CompressionFormat.marlin_24.value
|
||||
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
|
||||
assert weight_quant.symmetric
|
||||
return CompressedTensorsW4A16Sparse24(
|
||||
strategy=weight_quant.strategy,
|
||||
num_bits=weight_quant.num_bits,
|
||||
group_size=weight_quant.group_size)
|
||||
if (self.quant_format == CompressionFormat.pack_quantized.value
|
||||
and weight_quant.num_bits in WNA16_SUPPORTED_BITS):
|
||||
return CompressedTensorsWNA16(
|
||||
num_bits=weight_quant.num_bits,
|
||||
strategy=weight_quant.strategy,
|
||||
symmetric=weight_quant.symmetric,
|
||||
group_size=weight_quant.group_size,
|
||||
actorder=weight_quant.actorder)
|
||||
|
||||
if is_activation_quantization_format(self.quant_format):
|
||||
if self._is_fp4a4_nvfp4(weight_quant, input_quant):
|
||||
return CompressedTensorsW4A4Fp4()
|
||||
|
||||
if self._is_fp8_w8a8(weight_quant, input_quant):
|
||||
is_fp8_w8a8_supported = self._check_scheme_supported(
|
||||
CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
|
||||
if is_fp8_w8a8_supported:
|
||||
return CompressedTensorsW8A8Fp8(
|
||||
strategy=weight_quant.strategy,
|
||||
is_static_input_scheme=(input_quant
|
||||
and not input_quant.dynamic))
|
||||
else:
|
||||
# note: input_quant will be present for converted models;
|
||||
# will be ignored during inference post loading
|
||||
return CompressedTensorsW8A16Fp8(
|
||||
strategy=weight_quant.strategy,
|
||||
is_static_input_scheme=not input_quant.dynamic)
|
||||
|
||||
# note: input_quant can be None
|
||||
if self._is_fp8_w8a16(weight_quant, input_quant):
|
||||
is_static_input_scheme = (input_quant
|
||||
and not input_quant.dynamic)
|
||||
return CompressedTensorsW8A16Fp8(
|
||||
strategy=weight_quant.strategy,
|
||||
is_static_input_scheme=is_static_input_scheme)
|
||||
|
||||
if self._is_static_tensor_w8a8(weight_quant, input_quant):
|
||||
return CompressedTensorsW8A8Int8(
|
||||
strategy=weight_quant.strategy,
|
||||
is_static_input_scheme=True,
|
||||
input_symmetric=input_quant.symmetric)
|
||||
|
||||
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
|
||||
return CompressedTensorsW8A8Int8(
|
||||
strategy=weight_quant.strategy,
|
||||
is_static_input_scheme=False,
|
||||
input_symmetric=input_quant.symmetric)
|
||||
|
||||
raise NotImplementedError(
|
||||
"No compressed-tensors compatible scheme was found.")
|
||||
|
||||
def get_scheme(self,
|
||||
layer: torch.nn.Module,
|
||||
layer_name: Optional[str] = None
|
||||
) -> Optional["CompressedTensorsScheme"]:
|
||||
"""
|
||||
compressed-tensors supports non uniform in the following way:
|
||||
|
||||
targets of config_groups: There can be N config_groups which each
|
||||
have a quantization scheme. Each config_group has a list of targets
|
||||
which can be a full layer_name, a regex for a layer_name, or
|
||||
an nn.Module name.
|
||||
|
||||
Detect whether a layer_name is found in any target and
|
||||
use the quantization scheme corresponding to the matched target
|
||||
to select the CompressedTensorsScheme used for inference.
|
||||
"""
|
||||
|
||||
# Find the "target" in the compressed-tensors config
|
||||
# that our layer conforms to.
|
||||
# TODO (@robertgshaw): add compressed-tensors as dep
|
||||
# so we do not have to re-write these functions
|
||||
# need to make accelerate optional in ct to do this
|
||||
|
||||
# Will be empty for models with only sparsity
|
||||
weight_quant = input_quant = None
|
||||
if self.target_scheme_map:
|
||||
matched_target = find_matched_target(
|
||||
layer_name=layer_name,
|
||||
module=layer,
|
||||
targets=self.target_scheme_map.keys(),
|
||||
fused_mapping=self.packed_modules_mapping)
|
||||
|
||||
scheme_dict = self.target_scheme_map[matched_target]
|
||||
weight_quant = scheme_dict.get("weights")
|
||||
input_quant = scheme_dict.get("input_activations")
|
||||
|
||||
# Find the sparsity scheme of the layer
|
||||
# assume that fused layers inerhit first component's sparsity scheme
|
||||
sparsity_targets = (self.sparsity_scheme_map.keys() -
|
||||
set(self.sparsity_ignore_list))
|
||||
sparsity_scheme: Optional[SparsityCompressionConfig] = None
|
||||
with suppress(ValueError):
|
||||
matched_target = find_matched_target(
|
||||
layer_name=layer_name,
|
||||
module=layer,
|
||||
targets=sparsity_targets,
|
||||
fused_mapping=self.packed_modules_mapping)
|
||||
sparsity_scheme = self.sparsity_scheme_map[matched_target]
|
||||
|
||||
if self.supports_cutlass_24(weight_quant=weight_quant,
|
||||
input_quant=input_quant,
|
||||
sparsity_scheme=sparsity_scheme):
|
||||
# Have a valid sparsity scheme
|
||||
# Validate layer is supported by Cutlass 2:4 Kernel
|
||||
model_compression_config = (None if sparsity_scheme is None
|
||||
or sparsity_scheme.format == "dense"
|
||||
else self.config)
|
||||
|
||||
scheme = CompressedTensors24(
|
||||
quantized=weight_quant is not None or input_quant is not None,
|
||||
weight_quant=weight_quant,
|
||||
input_quant=input_quant,
|
||||
model_compression_config=model_compression_config,
|
||||
)
|
||||
elif weight_quant is None:
|
||||
logger.warning_once("Acceleration for non-quantized schemes is "
|
||||
"not supported by Compressed Tensors. "
|
||||
"Falling back to UnquantizedLinearMethod")
|
||||
return None
|
||||
|
||||
else:
|
||||
# Find the quant_scheme
|
||||
scheme = self._get_scheme_from_parts( # type: ignore
|
||||
weight_quant=weight_quant,
|
||||
input_quant=input_quant,
|
||||
)
|
||||
|
||||
# Raise error if device does not support the scheme
|
||||
# (e.g. fp8 needs ada lovelace)
|
||||
self._check_scheme_supported(scheme.get_min_capability())
|
||||
logger.debug("Using scheme: %s for %s", scheme.__class__.__name__,
|
||||
layer_name)
|
||||
return scheme
|
||||
|
||||
def get_cache_scale(self, name: str) -> Optional[str]:
|
||||
"""
|
||||
Check whether the param name matches the format for k/v cache scales
|
||||
in compressed-tensors. If this is the case, return its equivalent
|
||||
param name expected by vLLM
|
||||
|
||||
:param name: param name
|
||||
:return: matching param name for KV cache scale in vLLM
|
||||
"""
|
||||
if name.endswith(".output_scale") and ".k_proj" in name:
|
||||
return name.replace(".k_proj.output_scale", ".attn.k_scale")
|
||||
if name.endswith(".output_scale") and ".v_proj" in name:
|
||||
return name.replace(".v_proj.output_scale", ".attn.v_scale")
|
||||
# If no matches, return None
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def supports_cutlass_24(
|
||||
weight_quant: Optional[QuantizationArgs],
|
||||
input_quant: Optional[QuantizationArgs],
|
||||
sparsity_scheme: Optional[SparsityCompressionConfig] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the layer is supported by the Cutlass 2:4 Kernel
|
||||
Conditions:
|
||||
- Overarching condition: Sparsity Structure is 2:4
|
||||
- Unquantized cases are supported
|
||||
- Weight only quantization is not-supported
|
||||
- Supported weight quantization strategies are TENSOR and CHANNEL
|
||||
- Supported input quantization strategies are TENSOR and TOKEN
|
||||
- Only 8 bit quantization is supported
|
||||
|
||||
:return: True if the layer is supported by the Cutlass 2:4 Kernel
|
||||
False otherwise
|
||||
"""
|
||||
if sparsity_scheme is None:
|
||||
return False
|
||||
|
||||
is_valid_sparsity_structure: bool = (
|
||||
sparsity_scheme.sparsity_structure ==
|
||||
SparsityStructure.TWO_FOUR.value)
|
||||
|
||||
valid_compressors = {
|
||||
CompressionFormat.dense.value,
|
||||
CompressionFormat.sparse_24_bitmask.value
|
||||
}
|
||||
|
||||
is_valid_sparsity = (is_valid_sparsity_structure
|
||||
and sparsity_scheme.format in valid_compressors)
|
||||
|
||||
if not is_valid_sparsity:
|
||||
return False
|
||||
|
||||
# Unquantized cases are supported
|
||||
if weight_quant is None and input_quant is None:
|
||||
return True
|
||||
|
||||
# Weight only quantization is not-supported
|
||||
if weight_quant is not None and input_quant is None:
|
||||
return False
|
||||
|
||||
supported_weight_quant_strategies = [
|
||||
QuantizationStrategy.TENSOR.value,
|
||||
QuantizationStrategy.CHANNEL.value
|
||||
]
|
||||
|
||||
assert weight_quant is not None
|
||||
assert input_quant is not None
|
||||
if weight_quant.strategy not in supported_weight_quant_strategies:
|
||||
return False
|
||||
|
||||
supported_input_quant_strategies = [
|
||||
QuantizationStrategy.TENSOR.value, QuantizationStrategy.TOKEN.value
|
||||
]
|
||||
|
||||
if input_quant.strategy not in supported_input_quant_strategies:
|
||||
return False
|
||||
|
||||
return weight_quant.num_bits == input_quant.num_bits == 8
|
||||
|
||||
|
||||
class CompressedTensorsLinearMethod(LinearMethodBase):
|
||||
|
||||
def __init__(self, quantization_config: CompressedTensorsConfig):
|
||||
self.quantization_config = quantization_config
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.scheme.process_weights_after_loading(layer)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
"""
|
||||
Use the CompressedTensorsScheme associated with each layer to create
|
||||
the necessary parameters for the layer. See LinearMethodBase for param
|
||||
details
|
||||
"""
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
layer.scheme.create_weights(
|
||||
layer=layer,
|
||||
input_size=input_size,
|
||||
input_size_per_partition=input_size_per_partition,
|
||||
output_partition_sizes=output_partition_sizes,
|
||||
output_size=output_size,
|
||||
params_dtype=params_dtype,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
Use the output of create_weights and the CompressedTensorsScheme
|
||||
associated with the layer to apply the forward pass with the
|
||||
layer input. See LinearMethodBase for param details
|
||||
|
||||
"""
|
||||
|
||||
scheme = layer.scheme
|
||||
if scheme is None:
|
||||
raise ValueError("A scheme must be defined for each layer")
|
||||
return scheme.apply_weights(layer, x, bias=bias)
|
||||
|
||||
|
||||
class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
|
||||
"""
|
||||
Supports loading kv-cache scaling factors from compressed-tensors
|
||||
checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: CompressedTensorsConfig):
|
||||
self.validate_kv_cache_scheme(quant_config.kv_cache_scheme)
|
||||
super().__init__(quant_config)
|
||||
|
||||
@staticmethod
|
||||
def validate_kv_cache_scheme(kv_cache_scheme: Optional[dict[str, Any]]):
|
||||
"""
|
||||
Validator for the kv cache scheme. Useful for controlling the
|
||||
kv cache quantization schemes, that are being supported in vLLM
|
||||
:param kv_cache_scheme: the compressed-tensors kv cache scheme
|
||||
"""
|
||||
if kv_cache_scheme is None:
|
||||
return
|
||||
|
||||
type_ = kv_cache_scheme.get("type")
|
||||
num_bits = kv_cache_scheme.get("num_bits")
|
||||
|
||||
if type_ != "float" and num_bits != 8:
|
||||
raise NotImplementedError(
|
||||
"Currently supported kv cache quantization is "
|
||||
"num_bits=8, type=float, however "
|
||||
f"received num_bits={num_bits}, type={type_}")
|
||||
|
||||
strategy = kv_cache_scheme.get("strategy")
|
||||
if strategy != "tensor":
|
||||
raise NotImplementedError(
|
||||
"Only support per-tensor scaling factor "
|
||||
"for compressed-tensors KV cache. "
|
||||
f"Expected strategy: tensor, found strategy: {strategy}")
|
||||
|
||||
is_symmetric = kv_cache_scheme.get("symmetric")
|
||||
if not is_symmetric:
|
||||
raise NotImplementedError(
|
||||
"Only support symmetric scaling factor "
|
||||
"for compressed-tensors KV cache. "
|
||||
f"However found symmetric: {is_symmetric}")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,24 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .compressed_tensors_scheme import CompressedTensorsScheme
|
||||
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
|
||||
from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
|
||||
CompressedTensorsW4A16Sparse24)
|
||||
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
|
||||
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
||||
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
|
||||
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
|
||||
from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
|
||||
CompressedTensorsWNA16)
|
||||
|
||||
from .compressed_tensors_24 import CompressedTensors24 # isort: skip
|
||||
|
||||
__all__ = [
|
||||
"CompressedTensorsScheme", "CompressedTensorsWNA16",
|
||||
"CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24",
|
||||
"CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
|
||||
"WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
|
||||
"CompressedTensors24", "CompressedTensorsW4A16Fp4",
|
||||
"CompressedTensorsW4A4Fp4"
|
||||
]
|
||||
@@ -0,0 +1,358 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors import CompressionFormat, ModelCompressor
|
||||
from compressed_tensors.quantization import (QuantizationArgs,
|
||||
QuantizationStrategy,
|
||||
QuantizationType)
|
||||
from compressed_tensors.utils import combine_shards
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise, sparse_cutlass_supported)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
|
||||
__all__ = ["CompressedTensors24"]
|
||||
|
||||
|
||||
class CompressedTensors24(CompressedTensorsScheme):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quantized: bool = False,
|
||||
weight_quant: Optional[QuantizationArgs] = None,
|
||||
input_quant: Optional[QuantizationArgs] = None,
|
||||
model_compression_config: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
self.quantized = quantized
|
||||
self.weight_quant = weight_quant
|
||||
self.input_quant = input_quant
|
||||
self.model_compressor = (
|
||||
ModelCompressor.from_compression_config(model_compression_config)
|
||||
if model_compression_config is not None else None)
|
||||
self.do_sparse_decompress = (
|
||||
self.model_compressor is not None
|
||||
and self.model_compressor.sparsity_config.format
|
||||
== CompressionFormat.sparse_24_bitmask.value)
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# Only cutlass 3.x kernels are implemented so far
|
||||
return 90
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
if not sparse_cutlass_supported():
|
||||
raise ValueError(
|
||||
"Sparse CUTLASS not supported. vLLM must be built with "
|
||||
"CUDA 12.2 or later to use this feature")
|
||||
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size = input_size
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)
|
||||
|
||||
# parameter to store uncompressed weight
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=self.weights_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
if self.do_sparse_decompress:
|
||||
assert all(partition_size % 8 == 0
|
||||
for partition_size in output_partition_sizes
|
||||
), "All partitions must be divisible by 8 for "
|
||||
"2:4 sparse compressed models"
|
||||
|
||||
shape = BasevLLMParameter(
|
||||
data=torch.empty(2, 1, dtype=torch.int64),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
compressed_weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // 2,
|
||||
dtype=self.weights_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
bitmask = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // 8,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("shape", shape)
|
||||
layer.register_parameter("compressed", compressed_weight)
|
||||
layer.register_parameter("bitmask", bitmask)
|
||||
|
||||
# Check if quantized, not just 2:4 Sparse
|
||||
if self.quantized:
|
||||
if (self.weight_quant and self.weight_quant.strategy
|
||||
== QuantizationStrategy.CHANNEL.value):
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1),
|
||||
dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
else:
|
||||
assert (self.weight_quant and self.weight_quant.strategy
|
||||
== QuantizationStrategy.TENSOR.value)
|
||||
weight_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes),
|
||||
dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# input quant will be non-none
|
||||
if self.input_quant and not self.input_quant.dynamic:
|
||||
# register input quant scale
|
||||
assert (self.input_quant.strategy ==
|
||||
QuantizationStrategy.TENSOR.value)
|
||||
input_scale = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
else:
|
||||
# for sparse-only, pass in 1 for weight/input scales
|
||||
weight_scale = torch.nn.Parameter(data=torch.ones(
|
||||
1, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
input_scale = torch.nn.Parameter(data=torch.ones(
|
||||
1, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
"""
|
||||
Compress weights after loading. Store compressed weight and meta
|
||||
tensor
|
||||
|
||||
:post-condition: layer.w_compressed and layer.meta are
|
||||
set to the compressed weight and meta tensor in the
|
||||
format expected by the Cutlass kernels
|
||||
:param layer: The layer with the weights to be processed
|
||||
|
||||
"""
|
||||
if self.do_sparse_decompress:
|
||||
layer.weight.data = self._decompress_bitmask_compressed_weight(
|
||||
compressed=layer.compressed,
|
||||
bitmask=layer.bitmask,
|
||||
layer=layer,
|
||||
)
|
||||
|
||||
# compressed and bitmask tensors
|
||||
# are no longer needed after decompression
|
||||
del layer.compressed
|
||||
del layer.bitmask
|
||||
|
||||
# torch.compile workaround
|
||||
if hasattr(layer, "input_scale"):
|
||||
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
|
||||
requires_grad=False)
|
||||
|
||||
if self.weight_quant:
|
||||
if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
convert_to_channelwise(
|
||||
weight_scale=layer.weight_scale,
|
||||
logical_widths=layer.logical_widths,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
else:
|
||||
# torch.compile workaround
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
layer.weight_scale.data, requires_grad=False)
|
||||
|
||||
# Set all negative zero values to 0 prior to compression
|
||||
if (layer.weight.dtype.is_floating_point
|
||||
and layer.weight.dtype.itemsize >= 2):
|
||||
layer.weight.data[layer.weight.data == -0.0] = 0.0
|
||||
|
||||
w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data)
|
||||
layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
|
||||
layer.meta = torch.nn.Parameter(meta, requires_grad=False)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Returns the output tensor for the layer with 2:4
|
||||
sparse compressed weights, given the input tensor
|
||||
and bias
|
||||
|
||||
:param layer: The layer with 2:4 sparse compressed
|
||||
weights to be used for the computation
|
||||
:param x: The input tensor to the layer
|
||||
:param bias: The bias to be added to the output tensor
|
||||
:return: The output tensor of the layer
|
||||
"""
|
||||
if self.quantized:
|
||||
scale = None
|
||||
if hasattr(layer, "input_scale"):
|
||||
scale = layer.input_scale
|
||||
|
||||
if self.weights_dtype == torch.int8:
|
||||
ops_output = ops.scaled_int8_quant(x, scale=scale)
|
||||
q_input = ops_output[0]
|
||||
input_scale = ops_output[1]
|
||||
else:
|
||||
assert self.weights_dtype == torch.float8_e4m3fn
|
||||
if scale is not None:
|
||||
q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale)
|
||||
else:
|
||||
q_input, input_scale = ops.scaled_fp8_quant(
|
||||
x, use_per_token_if_dynamic=True)
|
||||
|
||||
else:
|
||||
# Not quantized, nothing to do with the input_scales, use as is
|
||||
input_scale = layer.input_scale
|
||||
q_input = x
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(
|
||||
a=q_input,
|
||||
bt_nzs=layer.weight,
|
||||
bt_meta=layer.meta,
|
||||
scale_a=input_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
out_dtype=x.dtype,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
assert out.is_contiguous()
|
||||
return out
|
||||
|
||||
def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
|
||||
if not self.quantized:
|
||||
return params_dtype
|
||||
|
||||
assert self.weight_quant is not None
|
||||
assert self.input_quant is not None
|
||||
|
||||
is_8_bits = self.weight_quant.num_bits == self.input_quant.num_bits == 8
|
||||
|
||||
if not is_8_bits:
|
||||
raise ValueError("Cutlass only supports 8-bit quantization")
|
||||
|
||||
if (self.weight_quant.type == QuantizationType.FLOAT
|
||||
and self.input_quant.type == QuantizationType.FLOAT):
|
||||
return torch.float8_e4m3fn
|
||||
|
||||
if (self.weight_quant.type == QuantizationType.INT
|
||||
and self.input_quant.type == QuantizationType.INT):
|
||||
return torch.int8
|
||||
|
||||
raise ValueError("Quantization type not supported by Cutlass")
|
||||
|
||||
def _decompress_bitmask_compressed_weight(
|
||||
self,
|
||||
compressed: torch.Tensor,
|
||||
bitmask: torch.Tensor,
|
||||
layer: torch.nn.Module,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Decompress a compressed 2:4 sparse weight tensor using the bitmask and
|
||||
return the result.
|
||||
|
||||
This function also supports sharded decompression.
|
||||
|
||||
:param compressed: The 2:4 sparse weight tensor compressed using the
|
||||
sparse-24-bitmask compressor. This is different from
|
||||
`cutlass_sparse_compress` which uses a different scheme (2 bits for
|
||||
every nonzero element that represent the coordinate within the block
|
||||
of 4). The bitmask compression here uses a bitmask to indicate the
|
||||
positions of non-zero elements.
|
||||
:param bitmask: The 2:4 bitmask associated with the compressed weights,
|
||||
representing the positions of non-zero elements in the compressed
|
||||
tensor.
|
||||
:param layer: The layer whose weights need to be processed after
|
||||
loading.
|
||||
:return: The decompressed 2:4 sparse weight tensor.
|
||||
"""
|
||||
|
||||
sparsity_compressor = self.model_compressor.sparsity_compressor
|
||||
|
||||
def _process_split(
|
||||
bitmask_compressed_weight: torch.Tensor,
|
||||
shape,
|
||||
bitmask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
weight_data = dict(
|
||||
compressed=bitmask_compressed_weight,
|
||||
shape=shape,
|
||||
bitmask=bitmask,
|
||||
)
|
||||
return sparsity_compressor.decompress_weight(weight_data)
|
||||
|
||||
split_weights: list[torch.Tensor] = []
|
||||
split_bitmask: list[torch.Tensor] = []
|
||||
split_shape: list[tuple[int, int]] = []
|
||||
|
||||
if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
|
||||
split_weights = torch.split(compressed, layer.logical_widths)
|
||||
split_bitmask = torch.split(bitmask, layer.logical_widths)
|
||||
split_shape = [(out, layer.input_size_per_partition)
|
||||
for out in layer.logical_widths]
|
||||
|
||||
if split_weights:
|
||||
decompressed_shards = [
|
||||
_process_split(compressed_weight, shape, bitmask)
|
||||
for compressed_weight, shape, bitmask in zip(
|
||||
split_weights, split_shape, split_bitmask)
|
||||
]
|
||||
decompressed = combine_shards(decompressed_shards)
|
||||
else:
|
||||
decompressed = sparsity_compressor.decompress_weight(
|
||||
dict(
|
||||
compressed=compressed,
|
||||
shape=(
|
||||
layer.logical_widths[0],
|
||||
layer.input_size_per_partition,
|
||||
),
|
||||
bitmask=bitmask,
|
||||
))
|
||||
return decompressed
|
||||
@@ -0,0 +1,55 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ["CompressedTensorsScheme"]
|
||||
|
||||
|
||||
class CompressedTensorsScheme(ABC):
|
||||
"""
|
||||
Abstract class used to describe the weight creation and forward pass
|
||||
of different quantization schemes supported by CompressedTensors.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
"""
|
||||
Get minimum device capability.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(self, *args, **kwargs):
|
||||
"""
|
||||
Weight creation for the particular scheme. Inputs to this function
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]):
|
||||
"""
|
||||
Run the forward pass for the particular scheme. This is where
|
||||
scheme-specific dequant/quant steps/kernels should be applied.
|
||||
|
||||
:param layer: torch.nn.Module with the registered weights and
|
||||
other parameters relevant to the particular scheme.
|
||||
:param x: input to the layer
|
||||
:param bias: bias parameter
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
"""
|
||||
Called after weight loading is complete for any cleanup that
|
||||
needs to occur.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,160 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
__all__ = ["CompressedTensorsW4A16Sparse24"]
|
||||
W4A16SPARSE24_SUPPORTED_TYPES_MAP = {
|
||||
4: scalar_types.uint4b8,
|
||||
}
|
||||
W4A16SPARSE24_SUPPORTED_BITS = list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys())
|
||||
|
||||
|
||||
class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
|
||||
|
||||
def __init__(self,
|
||||
strategy: str,
|
||||
num_bits: int,
|
||||
group_size: Optional[int] = None):
|
||||
self.strategy = strategy
|
||||
self.group_size = group_size
|
||||
self.tile_size = 16
|
||||
|
||||
if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP:
|
||||
raise ValueError(
|
||||
f"Unsupported num_bits = {num_bits}. "
|
||||
f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}")
|
||||
|
||||
self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits]
|
||||
|
||||
if self.strategy == "group" and self.group_size is None:
|
||||
raise ValueError(
|
||||
"group_size must be given when using strategy group")
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# ampere + up
|
||||
return 80
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.weight_packed = Parameter(layer.weight_packed.data,
|
||||
requires_grad=False)
|
||||
layer.scale_packed = Parameter(layer.scale_packed.data,
|
||||
requires_grad=False)
|
||||
layer.meta = Parameter(layer.meta.data, requires_grad=False)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, input_size: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
|
||||
assert params_dtype == torch.float16, (
|
||||
"float16 is required for marlin24 compressed models. Set dtype=torch.float16" # noqa: E501
|
||||
)
|
||||
|
||||
pack_factor = 32 // self.quant_type.size_bits
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
qweight = PackedvLLMParameter(data=torch.empty(
|
||||
input_size_per_partition // self.tile_size // 2,
|
||||
output_size_per_partition * self.tile_size // pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=pack_factor,
|
||||
marlin_tile_size=self.tile_size,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
input_groups = (1 if self.group_size is None else
|
||||
input_size_per_partition // self.group_size)
|
||||
|
||||
weight_scale_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
input_groups,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
|
||||
if self.group_size is not None:
|
||||
scales = GroupQuantScaleParameter(output_dim=1,
|
||||
input_dim=0,
|
||||
**weight_scale_args)
|
||||
else:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1,
|
||||
**weight_scale_args)
|
||||
|
||||
weight_shape = BasevLLMParameter(data=torch.empty(2,
|
||||
dtype=torch.int64),
|
||||
weight_loader=weight_loader)
|
||||
|
||||
meta = PackedvLLMParameter(data=torch.empty(
|
||||
input_size_per_partition // 8 // 2 // 2,
|
||||
output_size_per_partition * 2,
|
||||
dtype=torch.int16,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=1,
|
||||
marlin_tile_size=2,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("weight_packed", qweight)
|
||||
layer.register_parameter("weight_shape", weight_shape)
|
||||
layer.register_parameter("scale_packed", scales)
|
||||
layer.register_parameter("meta", meta)
|
||||
|
||||
max_workspace_size = (
|
||||
output_size_per_partition //
|
||||
GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL
|
||||
|
||||
workspace = Parameter(torch.zeros(max_workspace_size, dtype=torch.int),
|
||||
requires_grad=False)
|
||||
layer.workspace = workspace
|
||||
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
|
||||
qweight = layer.weight_packed
|
||||
meta = layer.meta
|
||||
scales = layer.scale_packed
|
||||
workspace = layer.workspace
|
||||
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
size_m = x_2d.shape[0]
|
||||
size_k = x_2d.shape[1]
|
||||
size_n = scales.shape[1]
|
||||
|
||||
output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
|
||||
workspace, self.quant_type, size_m,
|
||||
size_n, size_k)
|
||||
|
||||
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output
|
||||
@@ -0,0 +1,93 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
apply_fp4_marlin_linear, prepare_fp4_layer_for_marlin)
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
|
||||
__all__ = ["CompressedTensorsW4A16Fp4"]
|
||||
|
||||
|
||||
class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
|
||||
|
||||
def __init__(self):
|
||||
self.group_size = 16
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# dont restrict as emulations
|
||||
return 80
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
|
||||
# Weight
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // 2,
|
||||
dtype=torch.uint8),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight_packed", weight)
|
||||
|
||||
# Global Weight Scale
|
||||
weight_global_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight_global_scale", weight_global_scale)
|
||||
|
||||
# Per Group Weight Scale
|
||||
weight_scale = GroupQuantScaleParameter(data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // self.group_size,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
# Process parameters for marlin repacking
|
||||
|
||||
# Rename weight_packed to weight that marlin expects
|
||||
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
|
||||
del layer.weight_packed
|
||||
# Rename weight_global_scale to weight_scale_2 that marlin expects
|
||||
# Note: ct stores the inverse of what is expected by the marlin kernel
|
||||
layer.weight_scale_2 = Parameter(
|
||||
1 / layer.weight_global_scale.max().to(torch.float32),
|
||||
requires_grad=False)
|
||||
del layer.weight_global_scale
|
||||
|
||||
prepare_fp4_layer_for_marlin(layer)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return apply_fp4_marlin_linear(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_scale_2=layer.weight_scale_2,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias)
|
||||
@@ -0,0 +1,178 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm._custom_ops import (cutlass_scaled_fp4_mm,
|
||||
cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
|
||||
dequantize_to_dtype, ref_nvfp4_quant)
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["CompressedTensorsW4A4Fp4"]
|
||||
|
||||
|
||||
def cutlass_fp4_supported() -> bool:
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
capability = -1 if capability_tuple is None else capability_tuple.to_int()
|
||||
return cutlass_scaled_mm_supports_fp4(capability)
|
||||
|
||||
|
||||
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
|
||||
def __init__(self):
|
||||
self.group_size = 16
|
||||
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
|
||||
if not self.cutlass_nvfp4_supported:
|
||||
logger.warning("Current platform does not support cutlass NVFP4."
|
||||
" Running emulations.")
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# dont restrict as emulations
|
||||
return 80
|
||||
|
||||
def run_nvfp4_emulations(self, x: torch.Tensor, layer):
|
||||
x_m, x_k = x.shape
|
||||
output_dtype = x.dtype
|
||||
|
||||
# quantize input to (FP4 and interleaved block scale)
|
||||
x_fp4, x_blockscale = ref_nvfp4_quant(x, layer.input_global_scale,
|
||||
self.group_size)
|
||||
|
||||
# dequantize input
|
||||
x_fp4 = x_fp4.reshape(x_m, x_k // self.group_size, self.group_size)
|
||||
x_blockscale = x_blockscale.unsqueeze(-1) / layer.input_global_scale
|
||||
x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
|
||||
del x_fp4, x_blockscale
|
||||
|
||||
# dequantize weight
|
||||
w_fp4 = layer.weight.data.view(torch.uint8)
|
||||
w_blockscale = layer.weight_scale_swizzled.data
|
||||
w_global_scale = layer.weight_global_scale
|
||||
w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
|
||||
output_dtype, x.device, self.group_size)
|
||||
|
||||
# matmul
|
||||
out = torch.matmul(x_dq, w_dq.t())
|
||||
del w_dq, x_dq
|
||||
return out
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
|
||||
# Weight
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // 2,
|
||||
dtype=torch.uint8),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight_packed", weight)
|
||||
|
||||
# Global Weight Scale
|
||||
weight_global_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight_global_scale", weight_global_scale)
|
||||
|
||||
# Per Group Weight Scale
|
||||
weight_scale = GroupQuantScaleParameter(data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // self.group_size,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
input_global_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("input_global_scale", input_global_scale)
|
||||
|
||||
def swizzle_blockscale(self, scale: torch.tensor):
|
||||
assert (scale.dtype == torch.float8_e4m3fn)
|
||||
# Pad and blockwise interleave weight_scale
|
||||
scale_ndim = scale.ndim
|
||||
if scale.ndim == 2:
|
||||
scale = scale.unsqueeze(0)
|
||||
assert scale.ndim == 3
|
||||
B, M, K = scale.shape
|
||||
round_up_multiple = lambda x, m: (x + m - 1) // m * m
|
||||
M_padded = round_up_multiple(M, 128)
|
||||
K_padded = round_up_multiple(K, 4)
|
||||
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
|
||||
padded_scale[:B, :M, :K] = scale
|
||||
batches, rows, cols = padded_scale.shape
|
||||
assert rows % 128 == 0
|
||||
assert cols % 4 == 0
|
||||
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
|
||||
cols // 4, 4)
|
||||
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
|
||||
swizzled_scale = swizzled_scale.contiguous().cuda()
|
||||
return (swizzled_scale.reshape(M, K)
|
||||
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
|
||||
global_input_scale = layer.input_global_scale.max().to(torch.float32)
|
||||
layer.input_global_scale = Parameter(global_input_scale,
|
||||
requires_grad=False)
|
||||
|
||||
layer.weight_global_scale = Parameter(
|
||||
layer.weight_global_scale.max().to(torch.float32),
|
||||
requires_grad=False)
|
||||
|
||||
swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
|
||||
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
|
||||
requires_grad=False)
|
||||
|
||||
# required by cutlass kernel; need Parameter, not ModelWeightParameter
|
||||
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
|
||||
|
||||
if self.cutlass_nvfp4_supported:
|
||||
layer.alpha = Parameter(layer.input_global_scale *
|
||||
layer.weight_global_scale,
|
||||
requires_grad=False)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
if self.cutlass_nvfp4_supported:
|
||||
output_dtype = x.dtype
|
||||
output_shape = [x.shape[0], layer.weight.shape[0]]
|
||||
|
||||
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
|
||||
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
|
||||
|
||||
out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
|
||||
layer.weight_scale_swizzled,
|
||||
1 / layer.alpha, output_dtype)
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out.view(*output_shape)
|
||||
return self.run_nvfp4_emulations(x, layer)
|
||||
@@ -0,0 +1,121 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
|
||||
__all__ = ["CompressedTensorsW8A16Fp8"]
|
||||
|
||||
SUPPORTED_STRATEGIES = [
|
||||
QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR
|
||||
]
|
||||
|
||||
|
||||
class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
|
||||
|
||||
def __init__(self, strategy: str, is_static_input_scheme: bool):
|
||||
self.strategy = strategy
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# ampere and up
|
||||
return 80
|
||||
|
||||
# W8A8-Fp8 kernels support only per-tensor and per-channel cases.
|
||||
# So if we have a fused module (QKV, MLP) with per tensor scales,
|
||||
# we expand each scale to its shard's channels.
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
if self.strategy == QuantizationStrategy.TENSOR:
|
||||
ws_channelwise = convert_to_channelwise(layer.weight_scale,
|
||||
layer.logical_widths)
|
||||
layer.weight_scale = torch.nn.Parameter(ws_channelwise,
|
||||
requires_grad=False)
|
||||
else:
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
|
||||
requires_grad=False)
|
||||
|
||||
# Weights must be transposed for marlin
|
||||
layer.weight = torch.nn.Parameter(layer.weight.t(),
|
||||
requires_grad=False)
|
||||
|
||||
if self.is_static_input_scheme:
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
|
||||
requires_grad=False)
|
||||
prepare_fp8_layer_for_marlin(layer)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, input_size: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=torch.float8_e4m3fn),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1),
|
||||
dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
elif self.strategy == QuantizationStrategy.TENSOR:
|
||||
weight_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported weight strategy={self.strategy}, "
|
||||
f"supported strategies are {SUPPORTED_STRATEGIES}")
|
||||
|
||||
weight_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE (to deal with converted checkpoints)
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
return apply_fp8_marlin_linear(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias)
|
||||
@@ -0,0 +1,150 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
|
||||
requantize_with_max_scale)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
__all__ = ["CompressedTensorsW8A8Fp8"]
|
||||
|
||||
|
||||
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
|
||||
def __init__(self, strategy: str, is_static_input_scheme: bool):
|
||||
self.strategy = strategy
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# lovelace and up
|
||||
return 89
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
# If per tensor, when we have a fused module (e.g. QKV) with per
|
||||
# tensor scales (thus N scales being passed to the kernel),
|
||||
# requantize so we can always run per tensor
|
||||
if self.strategy == QuantizationStrategy.TENSOR:
|
||||
max_w_scale, weight = requantize_with_max_scale(
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
logical_widths=layer.logical_widths,
|
||||
)
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
input_scale = getattr(layer, 'input_scale', None)
|
||||
|
||||
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight,
|
||||
weight_scale=max_w_scale,
|
||||
input_scale=input_scale)
|
||||
if input_scale is not None:
|
||||
layer.input_scale = Parameter(input_scale,
|
||||
requires_grad=False)
|
||||
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
||||
|
||||
# If channelwise, scales are already lined up, so just transpose.
|
||||
elif self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight = layer.weight
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
input_scale = getattr(layer, 'input_scale', None)
|
||||
|
||||
weight, weight_scale, input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=input_scale)
|
||||
if input_scale is not None:
|
||||
layer.input_scale = Parameter(input_scale,
|
||||
requires_grad=False)
|
||||
else:
|
||||
weight_scale = layer.weight_scale.data
|
||||
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown quantization strategy {self.strategy}")
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme and hasattr(layer, 'input_scale'):
|
||||
layer.input_scale = Parameter(layer.input_scale.max(),
|
||||
requires_grad=False)
|
||||
else:
|
||||
layer.input_scale = None
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
maybe_create_device_identity()
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=torch.float8_e4m3fn),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
# TODO: update create_xxx_parameter functions to return
|
||||
# the newly added parameters
|
||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1),
|
||||
dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
else:
|
||||
assert self.strategy == QuantizationStrategy.TENSOR
|
||||
weight_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
|
||||
# min requirement for fp8 kernels
|
||||
weight_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
input_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
return self.fp8_linear.apply(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias)
|
||||
@@ -0,0 +1,111 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self, strategy: str, is_static_input_scheme: bool,
|
||||
input_symmetric: bool):
|
||||
self.strategy = strategy
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.input_symmetric = input_symmetric
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# turing and up
|
||||
return 75
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
|
||||
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
|
||||
is_static_input_scheme=self.is_static_input_scheme,
|
||||
input_symmetric=self.input_symmetric)
|
||||
|
||||
kernel_type = choose_scaled_mm_linear_kernel(
|
||||
scaled_mm_linear_kernel_config)
|
||||
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for CompressedTensorsW8A8Int8",
|
||||
kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=torch.int8),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1),
|
||||
dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
else:
|
||||
assert self.strategy == QuantizationStrategy.TENSOR
|
||||
weight_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = BasevLLMParameter(data=torch.empty(
|
||||
1, dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
if not self.input_symmetric:
|
||||
# Note: compressed-tensors stores the zp using the same dtype
|
||||
# as the weights
|
||||
# AZP loaded as int8 but used as int32
|
||||
input_zero_point = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.int8),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
|
||||
self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
|
||||
w_q_param_name="weight",
|
||||
w_s_param_name="weight_scale",
|
||||
i_s_param_name="input_scale",
|
||||
i_zp_param_name="input_zero_point",
|
||||
azp_adj_param_name="azp_adj")
|
||||
|
||||
# Checkpoints are serialized in compressed-tensors format, which is
|
||||
# different from the format the kernel may want. Handle repacking here.
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
@@ -0,0 +1,201 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import ActivationOrdering
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
MPLinearLayerConfig, choose_mp_linear_kernel)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_repeat_scales_on_all_ranks)
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter)
|
||||
# yapf: enable
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["CompressedTensorsWNA16"]
|
||||
WNA16_SUPPORTED_TYPES_MAP = {
|
||||
4: scalar_types.uint4b8,
|
||||
8: scalar_types.uint8b128
|
||||
}
|
||||
WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}
|
||||
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
|
||||
|
||||
|
||||
class CompressedTensorsWNA16(CompressedTensorsScheme):
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self,
|
||||
strategy: str,
|
||||
num_bits: int,
|
||||
group_size: Optional[int] = None,
|
||||
symmetric: Optional[bool] = True,
|
||||
actorder: Optional[ActivationOrdering] = None):
|
||||
|
||||
self.pack_factor = 32 // num_bits
|
||||
self.strategy = strategy
|
||||
self.symmetric = symmetric
|
||||
self.group_size = -1 if group_size is None else group_size
|
||||
self.has_g_idx = actorder == ActivationOrdering.GROUP
|
||||
|
||||
if self.group_size == -1 and self.strategy != "channel":
|
||||
raise ValueError("Marlin kernels require group quantization or "
|
||||
"channelwise quantization, but found no group "
|
||||
"size and strategy is not channelwise.")
|
||||
|
||||
if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
|
||||
raise ValueError(
|
||||
f"Unsupported num_bits = {num_bits}. "
|
||||
f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")
|
||||
|
||||
self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
|
||||
if not self.symmetric else
|
||||
WNA16_SUPPORTED_TYPES_MAP[num_bits])
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# ampere and up
|
||||
return 80
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, output_size: int,
|
||||
input_size: int, output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
mp_linear_kernel_config = MPLinearLayerConfig(
|
||||
full_weight_shape=(input_size, output_size),
|
||||
partition_weight_shape=\
|
||||
(input_size_per_partition, output_size_per_partition),
|
||||
weight_type=self.quant_type,
|
||||
act_type=params_dtype,
|
||||
group_size=self.group_size,
|
||||
zero_points=not self.symmetric,
|
||||
has_g_idx=self.has_g_idx
|
||||
)
|
||||
|
||||
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
|
||||
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for CompressedTensorsWNA16",
|
||||
kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
# If group_size is -1, we are in channelwise case.
|
||||
group_size = self.group_size if self.group_size != -1 else input_size
|
||||
row_parallel = (input_size != input_size_per_partition)
|
||||
partition_scales = not marlin_repeat_scales_on_all_ranks(
|
||||
self.has_g_idx, self.group_size, row_parallel)
|
||||
|
||||
scales_and_zp_size = input_size // group_size
|
||||
|
||||
if partition_scales:
|
||||
assert input_size_per_partition % group_size == 0
|
||||
scales_and_zp_size = input_size_per_partition // group_size
|
||||
|
||||
weight = PackedvLLMParameter(input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
packed_factor=self.pack_factor,
|
||||
packed_dim=1,
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition //
|
||||
self.pack_factor,
|
||||
dtype=torch.int32,
|
||||
))
|
||||
|
||||
weight_scale_args = {
|
||||
"weight_loader":
|
||||
weight_loader,
|
||||
"data":
|
||||
torch.empty(
|
||||
output_size_per_partition,
|
||||
scales_and_zp_size,
|
||||
dtype=params_dtype,
|
||||
)
|
||||
}
|
||||
|
||||
zeros_args = {
|
||||
"weight_loader":
|
||||
weight_loader,
|
||||
"data":
|
||||
torch.zeros(
|
||||
output_size_per_partition // self.pack_factor,
|
||||
scales_and_zp_size,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
}
|
||||
|
||||
if not partition_scales:
|
||||
weight_scale = ChannelQuantScaleParameter(output_dim=0,
|
||||
**weight_scale_args)
|
||||
|
||||
if not self.symmetric:
|
||||
qzeros = PackedColumnParameter(output_dim=0,
|
||||
packed_dim=0,
|
||||
packed_factor=self.pack_factor,
|
||||
**zeros_args)
|
||||
else:
|
||||
weight_scale = GroupQuantScaleParameter(output_dim=0,
|
||||
input_dim=1,
|
||||
**weight_scale_args)
|
||||
if not self.symmetric:
|
||||
qzeros = PackedvLLMParameter(input_dim=1,
|
||||
output_dim=0,
|
||||
packed_dim=0,
|
||||
packed_factor=self.pack_factor,
|
||||
**zeros_args)
|
||||
|
||||
# A 2D array defining the original shape of the weights
|
||||
# before packing
|
||||
weight_shape = BasevLLMParameter(data=torch.empty(2,
|
||||
dtype=torch.int64),
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("weight_packed", weight)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
layer.register_parameter("weight_shape", weight_shape)
|
||||
|
||||
if not self.symmetric:
|
||||
layer.register_parameter("weight_zero_point", qzeros)
|
||||
|
||||
# group index (for activation reordering)
|
||||
if self.has_g_idx:
|
||||
weight_g_idx = RowvLLMParameter(data=torch.empty(
|
||||
input_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight_g_idx", weight_g_idx)
|
||||
|
||||
self.kernel = kernel_type(mp_linear_kernel_config,
|
||||
w_q_param_name="weight_packed",
|
||||
w_s_param_name="weight_scale",
|
||||
w_zp_param_name="weight_zero_point",
|
||||
w_gidx_param_name="weight_g_idx")
|
||||
|
||||
# Checkpoints are serialized in compressed-tensors format, which is
|
||||
# different from the format the kernel may want. Handle repacking here.
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
@@ -0,0 +1,206 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
def is_weak_contiguous(x: torch.Tensor):
|
||||
strides = x.stride()
|
||||
sizes = x.shape
|
||||
is_not_transpose = strides[0] == 1 and (strides[1] >= max(1, sizes[0]))
|
||||
is_transpose = strides[1] == 1 and (strides[0] >= max(1, sizes[1]))
|
||||
return is_transpose or is_not_transpose
|
||||
|
||||
|
||||
@triton.jit
|
||||
def scaled_mm_kernel(a_ptr, b_ptr, scale_a_ptr, scale_b_ptr, c_ptr, bias_ptr,
|
||||
M, N, K, stride_am, stride_ak, stride_bk, stride_bn,
|
||||
stride_cm, stride_cn, ACCUMULATOR_DTYPE: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
BLOCK_SIZE_SCALE_A: tl.constexpr,
|
||||
BLOCK_SIZE_SCALE_B: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
|
||||
pid_m = pid // num_pid_n
|
||||
pid_n = pid % num_pid_n
|
||||
|
||||
accumulator_dtype = ACCUMULATOR_DTYPE
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),
|
||||
dtype=accumulator_dtype)
|
||||
|
||||
# NOTE: Some tensor inputs are so large, they will cause int32 overflow
|
||||
# so it is necessary to use tl.int64 for all the offsets, else SEGV will
|
||||
# eventually occur.
|
||||
|
||||
# Offsets and masks.
|
||||
offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
||||
masks_am = offsets_am < M
|
||||
|
||||
offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
|
||||
masks_bn = offsets_bn < N
|
||||
|
||||
offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
|
||||
offsets_a = (stride_am * offsets_am[:, None] +
|
||||
stride_ak * offsets_k[None, :])
|
||||
offsets_b = (stride_bk * offsets_k[:, None] +
|
||||
stride_bn * offsets_bn[None, :])
|
||||
|
||||
# NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create
|
||||
# appropriate offsets and masks for each case. Same goes for
|
||||
# BLOCK_SIZE_SCALE_B.
|
||||
offsets_scale_am = (tl.arange(0, BLOCK_SIZE_SCALE_A) +
|
||||
(BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M)
|
||||
masks_scale_am = offsets_scale_am < M
|
||||
|
||||
offsets_scale_bn = (tl.arange(0, BLOCK_SIZE_SCALE_B) +
|
||||
(BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N)
|
||||
masks_scale_bn = offsets_scale_bn < N
|
||||
|
||||
a_ptrs = a_ptr + offsets_a
|
||||
b_ptrs = b_ptr + offsets_b
|
||||
|
||||
scale_a_ptrs = scale_a_ptr + offsets_scale_am
|
||||
scale_b_ptrs = scale_b_ptr + offsets_scale_bn
|
||||
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
masks_k = offsets_k < K
|
||||
masks_a = masks_am[:, None] & masks_k[None, :]
|
||||
a = tl.load(a_ptrs, mask=masks_a)
|
||||
|
||||
masks_b = masks_k[:, None] & masks_bn[None, :]
|
||||
b = tl.load(b_ptrs, mask=masks_b)
|
||||
|
||||
# Accumulate results.
|
||||
accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
|
||||
|
||||
offsets_k += BLOCK_SIZE_K
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
# Apply scale at end.
|
||||
masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
|
||||
scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a)
|
||||
# Need to broadcast to the appropriate size, if scale_a is already
|
||||
# (BLOCK_SIZE_M, 1) then it will broadcast to its own shape. Same goes
|
||||
# for scale_b below.
|
||||
scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1))
|
||||
accumulator = scale_a * accumulator.to(tl.float32)
|
||||
|
||||
masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :]
|
||||
scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b)
|
||||
scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1))
|
||||
accumulator = scale_b.T * accumulator.to(tl.float32)
|
||||
|
||||
# Convert to output format.
|
||||
c = accumulator.to(c_ptr.type.element_ty)
|
||||
|
||||
# Add bias, it's already in output format, so add it after conversion.
|
||||
if bias_ptr:
|
||||
offsets_bias = offsets_bn
|
||||
bias_ptrs = bias_ptr + offsets_bias
|
||||
bias_mask = offsets_bias < N
|
||||
bias = tl.load(bias_ptrs, bias_mask)
|
||||
c += bias
|
||||
|
||||
# Save output
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
|
||||
offs_cm = offs_cm.to(tl.int64)
|
||||
offs_cn = offs_cn.to(tl.int64)
|
||||
c_ptrs = (c_ptr + stride_cm * offs_cm[:, None] +
|
||||
stride_cn * offs_cn[None, :])
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
|
||||
# input - [M, K]
|
||||
# weight - [K, N]
|
||||
def triton_scaled_mm(input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: type[torch.dtype],
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
block_size_m: int = 32,
|
||||
block_size_n: int = 32,
|
||||
block_size_k: int = 32,
|
||||
use_heuristic=True) -> torch.Tensor:
|
||||
M, K = input.shape
|
||||
N = weight.shape[1]
|
||||
|
||||
assert N > 0 and K > 0 and M > 0
|
||||
assert weight.shape[0] == K
|
||||
assert input.dtype == weight.dtype
|
||||
|
||||
scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a
|
||||
scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b
|
||||
|
||||
assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
|
||||
assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
|
||||
[M, 1])
|
||||
assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size(
|
||||
[N, 1])
|
||||
assert out_dtype.is_floating_point
|
||||
assert bias is None or bias.is_floating_point()
|
||||
assert is_weak_contiguous(input)
|
||||
assert is_weak_contiguous(weight)
|
||||
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
|
||||
N, META['BLOCK_SIZE_N']), )
|
||||
|
||||
result = torch.empty((M, N), dtype=out_dtype, device=input.device)
|
||||
|
||||
has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1
|
||||
|
||||
if use_heuristic:
|
||||
is_small_N = N < 8192
|
||||
next_power_of_2_M = max(32, triton.next_power_of_2(M))
|
||||
if next_power_of_2_M <= 32:
|
||||
tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256)
|
||||
elif next_power_of_2_M <= 64:
|
||||
tile_shape = (64, 64, 256)
|
||||
elif next_power_of_2_M <= 128:
|
||||
tile_shape = (64, 128, 128)
|
||||
else:
|
||||
tile_shape = (128, 128, 128)
|
||||
|
||||
block_size_m, block_size_n, block_size_k = tile_shape
|
||||
|
||||
block_size_sa = 1 if has_scalar(scale_a) else block_size_m
|
||||
block_size_sb = 1 if has_scalar(scale_b) else block_size_n
|
||||
|
||||
accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32
|
||||
|
||||
# A = input, B = weight, C = result
|
||||
# A = M x K, B = K x N, C = M x N
|
||||
scaled_mm_kernel[grid](input,
|
||||
weight,
|
||||
scale_a,
|
||||
scale_b,
|
||||
result,
|
||||
bias,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
input.stride(0),
|
||||
input.stride(1),
|
||||
weight.stride(0),
|
||||
weight.stride(1),
|
||||
result.stride(0),
|
||||
result.stride(1),
|
||||
accumulator_dtype,
|
||||
BLOCK_SIZE_M=block_size_m,
|
||||
BLOCK_SIZE_N=block_size_n,
|
||||
BLOCK_SIZE_K=block_size_k,
|
||||
BLOCK_SIZE_SCALE_A=block_size_sa,
|
||||
BLOCK_SIZE_SCALE_B=block_size_sb)
|
||||
|
||||
return result.to(out_dtype)
|
||||
216
model_executor/layers/quantization/compressed_tensors/utils.py
Normal file
216
model_executor/layers/quantization/compressed_tensors/utils.py
Normal file
@@ -0,0 +1,216 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable, Mapping
|
||||
from types import MappingProxyType
|
||||
from typing import Optional
|
||||
|
||||
import regex as re
|
||||
from compressed_tensors import CompressionFormat
|
||||
from torch.nn import Module
|
||||
|
||||
|
||||
def is_activation_quantization_format(format: str) -> bool:
|
||||
_ACTIVATION_QUANTIZATION_FORMATS = [
|
||||
CompressionFormat.naive_quantized.value,
|
||||
CompressionFormat.int_quantized.value,
|
||||
CompressionFormat.float_quantized.value,
|
||||
CompressionFormat.nvfp4_pack_quantized.value
|
||||
]
|
||||
return format in _ACTIVATION_QUANTIZATION_FORMATS
|
||||
|
||||
|
||||
def should_ignore_layer(
|
||||
layer_name: Optional[str],
|
||||
ignore: Iterable[str] = tuple(),
|
||||
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
|
||||
) -> bool:
|
||||
if layer_name is None:
|
||||
return False
|
||||
|
||||
# layer_name = model.layers.0.self_attn.qkv_proj
|
||||
# proj_name = qkv_proj
|
||||
proj_name = layer_name.split(".")[-1]
|
||||
|
||||
# Fused layers like gate_up_proj or qkv_proj will not be fused
|
||||
# in the safetensors checkpoint. So, we convert the name
|
||||
# from the fused version to unfused + check to make sure that
|
||||
# each shard of the fused layer has the same scheme.
|
||||
if proj_name in fused_mapping and layer_name not in ignore:
|
||||
shard_proj_names = fused_mapping[proj_name]
|
||||
|
||||
# Convert fused_name --> [shard_names]
|
||||
shard_names = [
|
||||
layer_name.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in shard_proj_names
|
||||
]
|
||||
|
||||
# Layer should be ignored if shards are ignored.
|
||||
should_ignore_layer = None
|
||||
for shard_name in shard_names:
|
||||
should_ignore_shard = check_equal_or_regex_match(
|
||||
layer_name=shard_name, targets=ignore)
|
||||
|
||||
# If shard_idx=0, set layer ignore to match shard.
|
||||
if should_ignore_layer is None:
|
||||
should_ignore_layer = should_ignore_shard
|
||||
|
||||
# If shard_idx=1+ confirm scheme matches prior shards.
|
||||
elif should_ignore_shard != should_ignore_layer:
|
||||
raise ValueError(f"Found a different quantization schemes for "
|
||||
f"{shard_proj_names} in {layer_name}. vLLM "
|
||||
"requires all to use the same scheme.")
|
||||
|
||||
# Unfused layers like down_proj and o_proj will match
|
||||
# the safetensors checkpoint already.
|
||||
else:
|
||||
should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name,
|
||||
targets=ignore)
|
||||
|
||||
assert should_ignore_layer is not None
|
||||
return should_ignore_layer
|
||||
|
||||
|
||||
def check_equal_or_regex_match(layer_name: str,
|
||||
targets: Iterable[str]) -> bool:
|
||||
"""
|
||||
Checks whether a layer_name is exactly equal or a regex match for
|
||||
if target starts with 're:' to any target in list.
|
||||
"""
|
||||
for target in targets:
|
||||
if _is_equal_or_regex_match(layer_name, target):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def find_matched_target(
|
||||
layer_name: Optional[str],
|
||||
module: Module,
|
||||
targets: Iterable[str],
|
||||
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
|
||||
) -> str:
|
||||
"""
|
||||
Helper function to look up which "target" in the compressed-tensors
|
||||
config that a layer corresponds to.
|
||||
|
||||
Recall that a compressed-tensors configs has a concept of
|
||||
config_groups, where each layer can be quantized with with a different
|
||||
scheme.
|
||||
|
||||
targets in each config_group will be a list of either layer names
|
||||
(or regexes corresponding to layer names) or names of torch Modules.
|
||||
|
||||
First, we try to match the layer_name with a target
|
||||
Second, we try to match the module's name with a target
|
||||
Third, we try to map the layer_name to a list of fused module names.
|
||||
*All* component module names must match in order for a match to be
|
||||
successful. A successful match returns the first component target
|
||||
|
||||
:param layer_name: layer name
|
||||
:param module: torch.nn.Module
|
||||
:param targets: list of targets to match the layer against
|
||||
:param fused_mapping: map from fused layer names to its components
|
||||
:param fused_strategy: either "all" or "any". If using "all", fused
|
||||
layers match if "all" of its components match
|
||||
"""
|
||||
|
||||
if layer_name is None:
|
||||
layer_name = ""
|
||||
|
||||
matched_target = (
|
||||
_find_first_match(layer_name, targets)
|
||||
or _find_first_match(module.__class__.__name__, targets, True)
|
||||
or _match_fused_layer(layer_name, targets, fused_mapping))
|
||||
|
||||
if matched_target is None:
|
||||
raise ValueError(
|
||||
f"Unable to find matching target for {layer_name} in the "
|
||||
"compressed-tensors config.")
|
||||
|
||||
return matched_target
|
||||
|
||||
|
||||
def _find_first_match(value: str,
|
||||
targets: Iterable[str],
|
||||
check_contains: bool = False) -> Optional[str]:
|
||||
"""
|
||||
Returns first element of target that matches value either
|
||||
exactly or as a regex after 're:'. If check_contains is set to True,
|
||||
additionally checks if the target string is contained within the value.
|
||||
|
||||
:param value: string to compare the list of targets against
|
||||
:param targets: list of targets to match the layer against
|
||||
:param check_contains: whether or not to do a substring match
|
||||
"""
|
||||
|
||||
for target in targets:
|
||||
if _is_equal_or_regex_match(value,
|
||||
target,
|
||||
check_contains=check_contains):
|
||||
return target
|
||||
return None
|
||||
|
||||
|
||||
def _is_equal_or_regex_match(value: str,
|
||||
target: str,
|
||||
check_contains: bool = False) -> bool:
|
||||
"""
|
||||
Checks whether a value is exactly equal or a regex match for target
|
||||
if target starts with 're:'. If check_contains is set to True,
|
||||
additionally checks if the target string is contained within the value.
|
||||
"""
|
||||
|
||||
if target.startswith("re:"):
|
||||
pattern = target[3:]
|
||||
if re.match(pattern, value):
|
||||
return True
|
||||
elif check_contains:
|
||||
if target.lower() in value.lower():
|
||||
return True
|
||||
elif target == value:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _match_fused_layer(
|
||||
layer_name: str, target_layers: Iterable[str],
|
||||
fused_mapping: Mapping[str, list[str]]) -> Optional[str]:
|
||||
"""
|
||||
Match a fused layer name to its corresponding individual layer in
|
||||
target_layers. Returns first value in fused_mapping which matches targets
|
||||
|
||||
Implements an "all" matching strategy where a fused layer matches iff
|
||||
"all" of its components match
|
||||
|
||||
:param layer_name: layer name
|
||||
:param target_layers: list of targets to match the layer against
|
||||
:param fused_mapping: map from fused layer names to its components
|
||||
|
||||
Examples:
|
||||
layer_name = "model.layers.0.self_attn.qkv_proj"
|
||||
target_layers = ["model.layers.0.self_attn.q_proj",
|
||||
"model.layers.0.self_attn.k_proj",
|
||||
"model.layers.0.self_attn.v_proj"]
|
||||
"""
|
||||
# find layer_name in mapping
|
||||
fused = next((key for key in fused_mapping if layer_name.endswith(key)),
|
||||
None)
|
||||
if fused is None:
|
||||
return None
|
||||
|
||||
# expand path of unfused components
|
||||
unfused_paths = [
|
||||
layer_name.replace(fused, unfused) for unfused in fused_mapping[fused]
|
||||
]
|
||||
|
||||
# for each unfused component, find a match in targets
|
||||
unfused_matches: list[Optional[str]] = []
|
||||
for unfused in unfused_paths:
|
||||
for target in target_layers:
|
||||
if _is_equal_or_regex_match(unfused, target):
|
||||
unfused_matches.append(target)
|
||||
break
|
||||
else:
|
||||
unfused_matches.append(None)
|
||||
|
||||
return unfused_matches[0] if all(unfused_matches) else None
|
||||
195
model_executor/layers/quantization/deepspeedfp.py
Normal file
195
model_executor/layers/quantization/deepspeedfp.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
|
||||
class DeepSpeedFPConfig(QuantizationConfig):
|
||||
"""Config for DeepSpeed FP quantizer. It supports fp6 and fp8.
|
||||
|
||||
Args:
|
||||
weight_bits: the target quantization bits, 6 or 8.
|
||||
group_size: group size for quantizaiton, default to 128.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int = 8,
|
||||
group_size: int = 512,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.valid_types = [torch.bfloat16, torch.float16]
|
||||
|
||||
if self.weight_bits not in (6, 8):
|
||||
raise ValueError(
|
||||
"Currently, only 6-bit or 8-bit weight quantization are "
|
||||
f"supported for DeepSpeed FP quantizaiton, but got "
|
||||
f"{self.weight_bits} bits.")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"DeepSpeedFPConfig(weight_bits={self.weight_bits}), "
|
||||
f"group_size={self.group_size}")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "deepspeedfp"
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "DeepSpeedFPConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
return cls(weight_bits=weight_bits, group_size=group_size)
|
||||
|
||||
def get_linear_method(self) -> "DeepSpeedFPLinearMethod":
|
||||
return DeepSpeedFPLinearMethod(self)
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
def get_min_capability(cls) -> int:
|
||||
return 60
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
return [
|
||||
"quant_config.json",
|
||||
"quantize_config.json",
|
||||
]
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["DeepSpeedFPLinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return DeepSpeedFPLinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class DeepSpeedFPLinearMethod(LinearMethodBase):
|
||||
"""Linear method for DeepSpeedFP quantizer.
|
||||
|
||||
Args:
|
||||
quant_config: the DeepSpeedFP quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: DeepSpeedFPConfig):
|
||||
self.quant_config = quant_config
|
||||
self.weight = None
|
||||
|
||||
def create_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader=None,
|
||||
**extra_weight_attrs):
|
||||
del output_size
|
||||
del input_size
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight = DeepSpeedFPParameter(
|
||||
torch.Size((output_size_per_partition, input_size_per_partition)),
|
||||
params_dtype=params_dtype,
|
||||
quant_config=self.quant_config,
|
||||
)
|
||||
set_weight_attrs(weight, {
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
})
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
def quant_weight_loader(param, loaded_weight, *args, **kwargs):
|
||||
# Calls the original weight loader (if any), quantizes the result,
|
||||
# and then loads the quantized parameter.
|
||||
if weight_loader is not None:
|
||||
orig_param_data = param.data
|
||||
param.data = param.ds_dequantize()
|
||||
weight_loader(param, loaded_weight, *args, **kwargs)
|
||||
param.data, loaded_weight = orig_param_data, param.data
|
||||
param.ds_quantize_(loaded_weight.cuda())
|
||||
|
||||
extra_weight_attrs["weight_loader"] = quant_weight_loader
|
||||
set_weight_attrs(weight, extra_weight_attrs)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
weight = layer.weight
|
||||
y = weight.ds_dequantize()
|
||||
return F.linear(x, y, bias)
|
||||
|
||||
|
||||
class DeepSpeedFPParameter(nn.Parameter):
|
||||
"""
|
||||
DeepSpeedFP quantized parameter class that implements fp8/fp6
|
||||
quantization deepspeed. Weights are stored in quantized form on
|
||||
GPUs, and can be dequantized on-the-fly when needed by the model.
|
||||
"""
|
||||
|
||||
def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype,
|
||||
quant_config: DeepSpeedFPConfig):
|
||||
try:
|
||||
import deepspeed
|
||||
if deepspeed.__version__ < "0.14.2":
|
||||
raise ImportError("deepspeed version is wrong. Please "
|
||||
"install deepspeed>=0.14.2.")
|
||||
from deepspeed.ops.fp_quantizer import FP_Quantize
|
||||
except ImportError as err:
|
||||
raise ImportError("Please install deepspeed>=0.14.2 via "
|
||||
"`pip install deepspeed>=0.14.2` to use "
|
||||
"deepspeedfp quantizer.") from err
|
||||
data = torch.empty((
|
||||
orig_shape.numel() // quant_config.group_size,
|
||||
quant_config.group_size * quant_config.weight_bits // 8 + 4,
|
||||
),
|
||||
dtype=torch.int8)
|
||||
self = torch.Tensor._make_subclass(cls, data, data.requires_grad)
|
||||
self.orig_shape = orig_shape
|
||||
self.quant_config = quant_config
|
||||
self.fp_quantizer = FP_Quantize(group_size=quant_config.group_size)
|
||||
self.fp_quantizer.orig_shape = orig_shape
|
||||
self.fp_quantizer.orig_dtype = params_dtype
|
||||
return self
|
||||
|
||||
def ds_quantize_(self, tensor: torch.Tensor):
|
||||
assert tensor.device.type == "cuda" and tensor.dtype != torch.int8
|
||||
return self.data.copy_(
|
||||
self.fp_quantizer.quantize(
|
||||
tensor.data,
|
||||
q_bits=self.quant_config.weight_bits,
|
||||
))
|
||||
|
||||
def ds_dequantize(self, fp_out=None) -> torch.Tensor:
|
||||
"""
|
||||
Return a tensor containing the dequantized weights of this parameter.
|
||||
"""
|
||||
assert self.data.device.type == "cuda" and self.data.dtype == torch.int8
|
||||
return self.fp_quantizer.dequantize(
|
||||
self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits)
|
||||
|
||||
def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor:
|
||||
"""
|
||||
Return a tensor where only the weights at `indices` are dequantized
|
||||
(to save HBM -> SRAM bandwidth).
|
||||
"""
|
||||
assert self.data.device.type == "cuda" and self.data.dtype == torch.int8
|
||||
return self.fp_quantizer.selective_dequantize(
|
||||
self.data,
|
||||
indices,
|
||||
fp_out=fp_out,
|
||||
q_bits=self.quant_config.weight_bits)
|
||||
196
model_executor/layers/quantization/experts_int8.py
Normal file
196
model_executor/layers/quantization/experts_int8.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
|
||||
class ExpertsInt8Config(QuantizationConfig):
|
||||
"""Config class for Int8 experts quantization."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "experts_int8"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "ExpertsInt8Config":
|
||||
return cls()
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return UnquantizedLinearMethod()
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return ExpertsInt8MoEMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(self, quant_config: ExpertsInt8Config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
|
||||
int8_dtype = torch.int8
|
||||
|
||||
assert 'weight_loader' in extra_weight_attrs
|
||||
weight_loader = extra_weight_attrs['weight_loader']
|
||||
wrapped_weight_loader = ExpertsInt8MoEMethod.quantizing_weight_loader(
|
||||
layer, weight_loader)
|
||||
extra_weight_attrs['weight_loader'] = wrapped_weight_loader
|
||||
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=int8_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=int8_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
w13_scale = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_scale", w13_scale)
|
||||
|
||||
w2_scale = torch.nn.Parameter(torch.zeros(num_experts,
|
||||
hidden_size,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_scale", w2_scale)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
use_int8_w8a16=True,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_scale,
|
||||
w2_scale=layer.w2_scale)
|
||||
|
||||
@staticmethod
|
||||
def quantizing_weight_loader(layer, weight_loader):
|
||||
|
||||
def quantize_and_call_weight_loader(param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
weight_name: str, shard_id: int,
|
||||
expert_id: int):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = layer.intermediate_size_per_partition
|
||||
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
|
||||
device = get_tp_group().device
|
||||
loaded_weight = loaded_weight.to(device)
|
||||
# w1, gate_proj case: Load into first shard of w13.
|
||||
if shard_id == "w1":
|
||||
scales = quantize_in_place_and_get_scales(
|
||||
loaded_weight[shard, :])
|
||||
layer.w13_scale.data[expert_id, 0:shard_size].copy_(scales[:,
|
||||
0])
|
||||
# w3, up_proj case: Load into second shard of w13.
|
||||
elif shard_id == "w3":
|
||||
scales = quantize_in_place_and_get_scales(
|
||||
loaded_weight[shard, :])
|
||||
layer.w13_scale.data[expert_id, shard_size:2 *
|
||||
shard_size].copy_(scales[:, 0])
|
||||
# w2, down_proj case: Load into only shard of w2.
|
||||
elif shard_id == "w2":
|
||||
scales = quantize_in_place_and_get_scales(loaded_weight[:,
|
||||
shard])
|
||||
layer.w2_scale.data[expert_id, :].copy_(scales[:, 0])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Shard id must be in [0,1,2] but got {shard_id}")
|
||||
weight_loader(param, loaded_weight, weight_name, shard_id,
|
||||
expert_id)
|
||||
|
||||
return quantize_and_call_weight_loader
|
||||
|
||||
|
||||
def quantize_in_place_and_get_scales(weight: torch.Tensor) -> torch.Tensor:
|
||||
vmax = torch.iinfo(torch.int8).max
|
||||
scales = (torch.max(torch.abs(weight), dim=1, keepdim=True)[0] / vmax)
|
||||
|
||||
weight.div_(scales)
|
||||
weight.round_()
|
||||
weight.clamp_(-vmax, vmax)
|
||||
|
||||
return scales
|
||||
172
model_executor/layers/quantization/fbgemm_fp8.py
Normal file
172
model_executor/layers/quantization/fbgemm_fp8.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
ModelWeightParameter)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FBGEMMFp8Config(QuantizationConfig):
|
||||
"""Config class for FBGEMM Fp8."""
|
||||
|
||||
def __init__(self, ignore_list: list[str], input_scale_ub: float):
|
||||
super().__init__()
|
||||
self.ignore_list = ignore_list if ignore_list else []
|
||||
self.input_scale_ub = input_scale_ub
|
||||
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
self.use_marlin = not current_platform.has_device_capability(89)
|
||||
self.fp8_linear = Fp8LinearOp()
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "fbgemm_fp8"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.float16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "FBGEMMFp8Config":
|
||||
ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
|
||||
input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
|
||||
return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(prefix=prefix,
|
||||
ignored_layers=self.ignore_list,
|
||||
fused_mapping=self.packed_modules_mapping):
|
||||
return UnquantizedLinearMethod()
|
||||
return FBGEMMFp8LinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||
|
||||
def __init__(self, quant_config: FBGEMMFp8Config):
|
||||
self.quant_config = quant_config
|
||||
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
maybe_create_device_identity()
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
del input_size, output_size
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.orig_dtype = params_dtype
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=torch.float8_e4m3fn),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
weight_scale = ChannelQuantScaleParameter(data=torch.empty(
|
||||
(sum(output_partition_sizes), 1), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
weight_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE UPPER BOUND
|
||||
input_scale_ub = torch.nn.Parameter(torch.tensor(
|
||||
(self.quant_config.input_scale_ub), dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.input_scale_ub = input_scale_ub
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
# required by torch.compile
|
||||
layer.weight_scale = Parameter(layer.weight_scale.data,
|
||||
requires_grad=False)
|
||||
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
||||
|
||||
weight = layer.weight
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
weight, weight_scale, input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=None)
|
||||
if input_scale is not None:
|
||||
layer.input_scale = Parameter(input_scale, requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
if self.quant_config.use_marlin:
|
||||
prepare_fp8_layer_for_marlin(layer)
|
||||
# Activations not quantized for marlin.
|
||||
del layer.input_scale_ub
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
if self.quant_config.use_marlin:
|
||||
return apply_fp8_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias)
|
||||
|
||||
return self.fp8_linear.apply(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
input_scale=None,
|
||||
input_scale_ub=layer.input_scale_ub,
|
||||
bias=bias)
|
||||
906
model_executor/layers/quantization/fp8.py
Normal file
906
model_executor/layers/quantization/fp8.py
Normal file
@@ -0,0 +1,906 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
import importlib.util
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
|
||||
prepare_moe_fp8_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
|
||||
cutlass_fp8_supported, maybe_create_device_identity,
|
||||
normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
|
||||
requantize_with_max_scale)
|
||||
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
|
||||
|
||||
|
||||
def _is_col_major(x: torch.Tensor) -> bool:
|
||||
assert x.dim() == 3
|
||||
b, m, n = x.shape
|
||||
return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m
|
||||
|
||||
|
||||
class Fp8Config(QuantizationConfig):
|
||||
"""Config class for FP8."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
is_checkpoint_fp8_serialized: bool = False,
|
||||
activation_scheme: str = "dynamic",
|
||||
ignored_layers: Optional[list[str]] = None,
|
||||
weight_block_size: Optional[list[int]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
|
||||
|
||||
if activation_scheme not in ACTIVATION_SCHEMES:
|
||||
raise ValueError(
|
||||
f"Unsupported activation scheme {activation_scheme}")
|
||||
self.activation_scheme = activation_scheme
|
||||
self.ignored_layers = ignored_layers or []
|
||||
if weight_block_size is not None:
|
||||
if not is_checkpoint_fp8_serialized:
|
||||
raise ValueError(
|
||||
"The block-wise quantization only supports fp8-serialized "
|
||||
"checkpoint for now.")
|
||||
if len(weight_block_size) != 2:
|
||||
raise ValueError(
|
||||
"The quantization block size of weight must have 2 "
|
||||
f"dimensions, but got {len(weight_block_size)} dimensions")
|
||||
if activation_scheme != "dynamic":
|
||||
raise ValueError("The block-wise quantization only supports "
|
||||
"dynamic activation scheme for now, but got "
|
||||
f"{activation_scheme} activation scheme.")
|
||||
self.weight_block_size = weight_block_size
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "fp8"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
|
||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
is_checkpoint_fp8_serialized = ("fp8" in quant_method)
|
||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
||||
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
|
||||
weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"],
|
||||
None)
|
||||
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
|
||||
activation_scheme=activation_scheme,
|
||||
ignored_layers=ignored_layers,
|
||||
weight_block_size=weight_block_size)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(prefix=prefix,
|
||||
ignored_layers=self.ignored_layers,
|
||||
fused_mapping=self.packed_modules_mapping):
|
||||
return UnquantizedLinearMethod()
|
||||
return Fp8LinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return Fp8MoEMethod(self)
|
||||
elif isinstance(layer, Attention):
|
||||
return Fp8KVCacheMethod(self)
|
||||
return None
|
||||
|
||||
def get_cache_scale(self, name: str) -> Optional[str]:
|
||||
"""
|
||||
Check whether the param name matches the format for k/v cache scales
|
||||
in compressed-tensors. If this is the case, return its equivalent
|
||||
param name expected by vLLM
|
||||
|
||||
:param name: param name
|
||||
:return: matching param name for KV cache scale in vLLM
|
||||
"""
|
||||
if name.endswith(".output_scale") and ".k_proj" in name:
|
||||
return name.replace(".k_proj.output_scale", ".attn.k_scale")
|
||||
if name.endswith(".output_scale") and ".v_proj" in name:
|
||||
return name.replace(".v_proj.output_scale", ".attn.v_scale")
|
||||
if name.endswith(".output_scale") and ".q_proj" in name:
|
||||
return name.replace(".q_proj.output_scale", ".attn.q_scale")
|
||||
if name.endswith("self_attn.prob_output_scale"):
|
||||
return name.replace(".prob_output_scale", ".attn.prob_scale")
|
||||
# If no matches, return None
|
||||
return None
|
||||
|
||||
|
||||
class Fp8LinearMethod(LinearMethodBase):
|
||||
"""Linear method for FP8.
|
||||
Supports loading FP8 checkpoints with static weight scale and
|
||||
dynamic/static activation scale.
|
||||
|
||||
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
|
||||
activation scaling. The weight scaling factor will be initialized after
|
||||
the model weights are loaded.
|
||||
|
||||
Limitations:
|
||||
1. Only support per-tensor quantization due to torch._scaled_mm support.
|
||||
2. Only support float8_e4m3fn data type due to the limitation of
|
||||
torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
|
||||
|
||||
Args:
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: Fp8Config):
|
||||
self.quant_config = quant_config
|
||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
self.use_marlin = (not current_platform.has_device_capability(89)
|
||||
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
|
||||
# Disable marlin for rocm
|
||||
if current_platform.is_rocm():
|
||||
self.use_marlin = False
|
||||
|
||||
# AITER is only supported on ROCm and only for FP8_FNUZ
|
||||
# and at the moment are MI300 series
|
||||
self.use_aiter_and_is_supported = (current_platform.is_rocm()
|
||||
and envs.VLLM_ROCM_USE_AITER
|
||||
and envs.VLLM_ROCM_USE_AITER_LINEAR
|
||||
and current_platform.is_fp8_fnuz())
|
||||
|
||||
self.block_quant = self.quant_config.weight_block_size is not None
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
# Default to using per_token quantization if cutlass is supported
|
||||
use_per_token_if_dynamic=cutlass_fp8_supported())
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
maybe_create_device_identity()
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
|
||||
if self.block_quant:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
assert self.quant_config.weight_block_size is not None
|
||||
layer.weight_block_size = self.quant_config.weight_block_size
|
||||
block_n, block_k = (
|
||||
self.quant_config.weight_block_size[0],
|
||||
self.quant_config.weight_block_size[1],
|
||||
)
|
||||
# Required by row parallel
|
||||
if (tp_size > 1
|
||||
and input_size // input_size_per_partition == tp_size
|
||||
and input_size_per_partition % block_k != 0):
|
||||
raise ValueError(
|
||||
f"Weight input_size_per_partition = "
|
||||
f"{input_size_per_partition} is not divisible by "
|
||||
f"weight quantization block_k = {block_k}.")
|
||||
# Required by column parallel or enabling merged weights
|
||||
if (tp_size > 1 and output_size // output_size_per_partition
|
||||
== tp_size) or len(output_partition_sizes) > 1:
|
||||
for output_partition_size in output_partition_sizes:
|
||||
if output_partition_size % block_n != 0:
|
||||
raise ValueError(
|
||||
f"Weight output_partition_size = "
|
||||
f"{output_partition_size} is not divisible by "
|
||||
f"weight quantization block_n = {block_n}.")
|
||||
|
||||
# WEIGHT
|
||||
weight_dtype = (torch.float8_e4m3fn
|
||||
if self.quant_config.is_checkpoint_fp8_serialized else
|
||||
params_dtype)
|
||||
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=weight_dtype),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# If checkpoint is serialized fp8, load them.
|
||||
# Otherwise, wait until process_weights_after_loading.
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
# WEIGHT SCALE
|
||||
if not self.block_quant:
|
||||
scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes),
|
||||
dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
set_weight_attrs(scale, {"scale_type": "weight_scale"})
|
||||
layer.register_parameter("weight_scale", scale)
|
||||
else:
|
||||
assert self.quant_config.activation_scheme == "dynamic"
|
||||
scale = BlockQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
(output_size_per_partition + block_n - 1) // block_n,
|
||||
(input_size_per_partition + block_k - 1) // block_k,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
set_weight_attrs(scale, {"scale_type": "weight_scale"})
|
||||
# The weight_scale_inv name is intentional for deepseekv3
|
||||
layer.register_parameter("weight_scale_inv", scale)
|
||||
|
||||
# INPUT ACTIVATION SCALE
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
set_weight_attrs(scale, {"scale_type": "input_scale"})
|
||||
layer.register_parameter("input_scale", scale)
|
||||
else:
|
||||
layer.register_parameter("input_scale", None)
|
||||
|
||||
def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
|
||||
# Pad the weight tensor. This is an optimization on ROCm platform, which
|
||||
# can benefit from tensors located far enough from one another in memory
|
||||
if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
|
||||
and weight.stride(-1) == 1
|
||||
and (weight.stride(-2) * weight.element_size()) % 512 == 0):
|
||||
num_pad = 256 // weight.element_size()
|
||||
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
|
||||
torch.cuda.empty_cache()
|
||||
return weight
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
size_k_first = True
|
||||
# TODO(rob): refactor block quant into separate class.
|
||||
if self.block_quant:
|
||||
assert self.quant_config.activation_scheme == "dynamic"
|
||||
size_k_first = False
|
||||
if current_platform.is_fp8_fnuz():
|
||||
weight, weight_scale_inv, _ = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale_inv)
|
||||
else:
|
||||
weight = layer.weight.data
|
||||
weight_scale_inv = layer.weight_scale_inv.data
|
||||
|
||||
weight = self._maybe_pad_weight(weight)
|
||||
|
||||
# Torch.compile cannot use Parameter subclasses.
|
||||
layer.weight = Parameter(weight, requires_grad=False)
|
||||
layer.weight_scale_inv = Parameter(weight_scale_inv,
|
||||
requires_grad=False)
|
||||
|
||||
# If checkpoint not serialized fp8, quantize the weights.
|
||||
elif not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
|
||||
scale=None)
|
||||
|
||||
# Update the layer with the new values.
|
||||
layer.weight = Parameter(qweight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
layer.input_scale = None
|
||||
|
||||
# If checkpoint is fp8, handle that there are N scales for N
|
||||
# shards in a fused module
|
||||
else:
|
||||
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
|
||||
requires_grad=False)
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
|
||||
requires_grad=False)
|
||||
|
||||
weight = layer.weight
|
||||
weight_scale = layer.weight_scale
|
||||
|
||||
# If using w8a8, torch._scaled_mm needs per tensor, so
|
||||
# requantize the logical shards as a single weight.
|
||||
if not self.use_marlin:
|
||||
# Dequant -> Quant with max scale so we can run per tensor.
|
||||
if current_platform.is_fp8_fnuz():
|
||||
weight, weight_scale, input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight,
|
||||
weight_scale=weight_scale,
|
||||
input_scale=layer.input_scale)
|
||||
if input_scale is not None:
|
||||
layer.input_scale = Parameter(input_scale,
|
||||
requires_grad=False)
|
||||
|
||||
weight_scale, weight = requantize_with_max_scale(
|
||||
weight=weight,
|
||||
weight_scale=weight_scale,
|
||||
logical_widths=layer.logical_widths,
|
||||
)
|
||||
|
||||
weight = self._maybe_pad_weight(weight)
|
||||
# Update layer with new values.
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
layer.input_scale = Parameter(layer.input_scale.max(),
|
||||
requires_grad=False)
|
||||
|
||||
if self.use_marlin:
|
||||
prepare_fp8_layer_for_marlin(layer, size_k_first)
|
||||
# Activations not quantized for marlin.
|
||||
del layer.input_scale
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
if self.use_marlin:
|
||||
return apply_fp8_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias)
|
||||
|
||||
if self.block_quant:
|
||||
assert self.quant_config.weight_block_size is not None
|
||||
return torch.ops.vllm.apply_w8a8_block_fp8_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
block_size=self.quant_config.weight_block_size,
|
||||
weight_scale=layer.weight_scale_inv,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||
)
|
||||
|
||||
return self.fp8_linear.apply(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias)
|
||||
|
||||
|
||||
class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
"""MoE method for FP8.
|
||||
Supports loading FP8 checkpoints with static weight scale and
|
||||
dynamic/static activation scale.
|
||||
|
||||
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
|
||||
activation scaling. The weight scaling factor will be initialized after
|
||||
the model weights are loaded.
|
||||
|
||||
Args:
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: Fp8Config):
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
self.quant_config = quant_config
|
||||
self.block_quant = self.quant_config.weight_block_size is not None
|
||||
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
self.use_marlin = (not current_platform.has_device_capability(89)
|
||||
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
|
||||
# Disable marlin for rocm
|
||||
if current_platform.is_rocm():
|
||||
self.use_marlin = False
|
||||
|
||||
# Check for DeepGemm support.
|
||||
self.allow_deep_gemm = False
|
||||
if envs.VLLM_USE_DEEP_GEMM:
|
||||
if not has_deep_gemm:
|
||||
logger.warning_once("Failed to import DeepGemm kernels.")
|
||||
elif not self.block_quant:
|
||||
logger.warning_once("Model is not block quantized. Not using "
|
||||
" DeepGemm kernels")
|
||||
elif (current_platform.is_cuda()
|
||||
and current_platform.has_device_capability(90)):
|
||||
logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
|
||||
self.allow_deep_gemm = True
|
||||
else:
|
||||
logger.warning_once(
|
||||
"DeepGemm not supported on the current platform.")
|
||||
|
||||
self.topk_indices_dtype = None
|
||||
self.fused_experts = functools.partial( # type: ignore
|
||||
fused_experts,
|
||||
use_fp8_w8a8=True,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
allow_deep_gemm=self.allow_deep_gemm)
|
||||
|
||||
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
|
||||
layer.intermediate_size_per_partition = intermediate_size_per_partition
|
||||
layer.hidden_size = hidden_size
|
||||
layer.num_experts = num_experts
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
params_dtype = torch.float8_e4m3fn
|
||||
if self.block_quant:
|
||||
assert self.quant_config.weight_block_size is not None
|
||||
layer.weight_block_size = self.quant_config.weight_block_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
block_n, block_k = (
|
||||
self.quant_config.weight_block_size[0],
|
||||
self.quant_config.weight_block_size[1],
|
||||
)
|
||||
# NOTE: To ensure proper alignment of the block-wise quantization
|
||||
# scales, the output_size of the weights for both the gate and up
|
||||
# layers must be divisible by block_n.
|
||||
# Required by column parallel or enabling merged weights
|
||||
if intermediate_size_per_partition % block_n != 0:
|
||||
raise ValueError(
|
||||
f"The output_size of gate's and up's weight = "
|
||||
f"{intermediate_size_per_partition} is not divisible by "
|
||||
f"weight quantization block_n = {block_n}.")
|
||||
if (tp_size > 1
|
||||
and intermediate_size_per_partition % block_k != 0):
|
||||
# Required by row parallel
|
||||
raise ValueError(
|
||||
f"The input_size of down's weight = "
|
||||
f"{intermediate_size_per_partition} is not divisible by "
|
||||
f"weight quantization block_k = {block_k}.")
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# WEIGHT_SCALES
|
||||
if not self.block_quant:
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
# They will be combined to a single scale after weight loading.
|
||||
w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
num_experts, 2, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
w2_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
num_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
else:
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts,
|
||||
2 * ((intermediate_size_per_partition + block_n - 1) //
|
||||
block_n),
|
||||
(hidden_size + block_k - 1) // block_k,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts,
|
||||
(hidden_size + block_n - 1) // block_n,
|
||||
(intermediate_size_per_partition + block_k - 1) // block_k,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
|
||||
assert self.quant_config.activation_scheme == "dynamic"
|
||||
|
||||
# Add the quantization method used (per tensor/grouped/channel)
|
||||
# to ensure the weight scales are loaded in properly
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.
|
||||
value} if self.block_quant else
|
||||
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
||||
# If loading fp8 checkpoint, pass the weight loaders.
|
||||
# If loading an fp16 checkpoint, do not (we will quantize in
|
||||
# process_weights_after_loading()
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
# INPUT_SCALES
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
raise ValueError(
|
||||
"Found static activation scheme for checkpoint that "
|
||||
"was not serialized fp8.")
|
||||
|
||||
w13_input_scale = torch.nn.Parameter(torch.ones(
|
||||
num_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
set_weight_attrs(w13_input_scale, extra_weight_attrs)
|
||||
|
||||
w2_input_scale = torch.nn.Parameter(torch.ones(
|
||||
num_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
||||
|
||||
else:
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
# Lazy import to avoid importing triton too early.
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled, shuffle_weights)
|
||||
|
||||
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
|
||||
|
||||
# TODO (rob): refactor block quant into separate class.
|
||||
if self.block_quant:
|
||||
assert self.quant_config.activation_scheme == "dynamic"
|
||||
if current_platform.is_fp8_fnuz():
|
||||
w13_weight, w13_weight_scale_inv, w13_input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
layer.w13_weight, layer.w13_weight_scale_inv,
|
||||
layer.w13_input_scale)
|
||||
w2_weight, w2_weight_scale_inv, w2_input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
layer.w2_weight, layer.w2_weight_scale_inv,
|
||||
layer.w2_input_scale)
|
||||
else:
|
||||
w13_weight = layer.w13_weight.data
|
||||
w13_weight_scale_inv = layer.w13_weight_scale_inv.data
|
||||
w2_weight = layer.w2_weight
|
||||
w2_weight_scale_inv = layer.w2_weight_scale_inv
|
||||
|
||||
# torch.compile() cannot use Parameter subclasses.
|
||||
layer.w13_weight = Parameter(w13_weight, requires_grad=False)
|
||||
layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
|
||||
layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
|
||||
requires_grad=False)
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
|
||||
requires_grad=False)
|
||||
|
||||
# DeepGemm scales need to be transposed and aligned. We try to do
|
||||
# it ahead of time for performance reasons.
|
||||
if self.allow_deep_gemm:
|
||||
# Lazy import to avoid CUDA initialization problems.
|
||||
import deep_gemm as dg
|
||||
if _is_col_major(layer.w13_weight_scale_inv):
|
||||
layer.w13_weight_scale_inv = \
|
||||
dg.get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous()
|
||||
if _is_col_major(layer.w2_weight_scale_inv):
|
||||
layer.w2_weight_scale_inv = \
|
||||
dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
|
||||
|
||||
# If checkpoint is fp16, quantize in place.
|
||||
elif not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
w13_weight = torch.empty_like(layer.w13_weight.data,
|
||||
dtype=fp8_dtype)
|
||||
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
|
||||
|
||||
# Re-initialize w13_scale because we directly quantize
|
||||
# merged w13 weights and generate a single scaling factor.
|
||||
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
layer.local_num_experts,
|
||||
dtype=torch.float32,
|
||||
device=w13_weight.device),
|
||||
requires_grad=False)
|
||||
for expert in range(layer.local_num_experts):
|
||||
w13_weight[expert, :, :], layer.w13_weight_scale[
|
||||
expert] = ops.scaled_fp8_quant(
|
||||
layer.w13_weight.data[expert, :, :])
|
||||
w2_weight[expert, :, :], layer.w2_weight_scale[
|
||||
expert] = ops.scaled_fp8_quant(
|
||||
layer.w2_weight.data[expert, :, :])
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight,
|
||||
requires_grad=False)
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
layer.w13_weight, layer.w2_weight)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
|
||||
requires_grad=False)
|
||||
# If checkpoint is fp8, we need to handle that the
|
||||
# MoE kernels require single activation scale and single weight
|
||||
# scale for w13 per expert.
|
||||
else:
|
||||
# Fp8 moe kernels require a single activation scale.
|
||||
# We take the max of all the scales in case they differ.
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
if (layer.w13_input_scale is None
|
||||
or layer.w2_input_scale is None):
|
||||
raise ValueError(
|
||||
"QuantConfig has static quantization, but found "
|
||||
"activation scales are None.")
|
||||
if (not all_close_1d(layer.w13_input_scale)
|
||||
or not all_close_1d(layer.w2_input_scale)):
|
||||
logger.warning_once(
|
||||
"Found input_scales that are not equal for "
|
||||
"fp8 MoE layer. Using the maximum across experts "
|
||||
"for each layer.")
|
||||
layer.w13_input_scale = torch.nn.Parameter(
|
||||
layer.w13_input_scale.max(), requires_grad=False)
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
layer.w2_input_scale.max(), requires_grad=False)
|
||||
if current_platform.is_fp8_fnuz():
|
||||
# Normalize the weights and scales
|
||||
w13_weight, w13_weight_scale, w13_input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
layer.w13_weight, layer.w13_weight_scale,
|
||||
layer.w13_input_scale)
|
||||
w2_weight, w2_weight_scale, w2_input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
layer.w2_weight, layer.w2_weight_scale,
|
||||
layer.w2_input_scale)
|
||||
# Reset the parameter
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight,
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
w13_weight_scale, requires_grad=False)
|
||||
if w13_input_scale is not None:
|
||||
layer.w13_input_scale = torch.nn.Parameter(
|
||||
w13_input_scale, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight,
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
|
||||
requires_grad=False)
|
||||
if w2_input_scale is not None:
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
w2_input_scale, requires_grad=False)
|
||||
|
||||
# Fp8 moe kernel needs single weight scale for w13 per expert.
|
||||
# We take the max then dequant and requant each expert.
|
||||
assert layer.w13_weight_scale is not None
|
||||
shard_size = layer.intermediate_size_per_partition
|
||||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||
for expert_id in range(layer.local_num_experts):
|
||||
start = 0
|
||||
for shard_id in range(2):
|
||||
dq_weight = per_tensor_dequantize(
|
||||
layer.w13_weight[expert_id][start:start +
|
||||
shard_size, :],
|
||||
layer.w13_weight_scale[expert_id][shard_id])
|
||||
layer.w13_weight[expert_id][
|
||||
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
|
||||
dq_weight, max_w13_scales[expert_id])
|
||||
start += shard_size
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
layer.w13_weight, layer.w2_weight)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
|
||||
requires_grad=False)
|
||||
|
||||
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
||||
requires_grad=False)
|
||||
|
||||
if self.use_marlin:
|
||||
prepare_moe_fp8_layer_for_marlin(layer, False)
|
||||
# Activations not quantized for marlin.
|
||||
del layer.w13_input_scale
|
||||
del layer.w2_input_scale
|
||||
|
||||
def select_gemm_impl(self, prepare_finalize, moe):
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
||||
BatchedTritonOrDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts)
|
||||
|
||||
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
|
||||
"Marlin and ROCm AITER are not supported with all2all yet.")
|
||||
|
||||
experts: Optional[Union[BatchedTritonOrDeepGemmExperts,
|
||||
TritonOrDeepGemmExperts]] = None
|
||||
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
|
||||
use_batched_experts = max_num_tokens_per_rank is not None
|
||||
|
||||
if use_batched_experts:
|
||||
experts = BatchedTritonOrDeepGemmExperts(
|
||||
max_num_tokens=max_num_tokens_per_rank,
|
||||
world_size=prepare_finalize.world_size,
|
||||
dp_size=prepare_finalize.dp_size,
|
||||
use_fp8_w8a8=True,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
per_channel_quant=False,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
allow_deep_gemm=self.allow_deep_gemm,
|
||||
)
|
||||
else:
|
||||
experts = TritonOrDeepGemmExperts(
|
||||
use_fp8_w8a8=True,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
allow_deep_gemm=self.allow_deep_gemm,
|
||||
)
|
||||
|
||||
assert experts is not None
|
||||
return experts
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype,
|
||||
)
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
||||
rocm_aiter_fused_experts)
|
||||
return rocm_aiter_fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
use_fp8_w8a8=True,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
w1_scale=(layer.w13_weight_scale_inv
|
||||
if self.block_quant else layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale_inv
|
||||
if self.block_quant else layer.w2_weight_scale),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
block_shape=self.quant_config.weight_block_size)
|
||||
elif self.use_marlin:
|
||||
assert activation == "silu", (
|
||||
f"{activation} not supported for Marlin MoE.")
|
||||
assert not apply_router_weight_on_input, (
|
||||
"Apply router weight on input not supported for Marlin MoE.")
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
layer.w13_weight_scale,
|
||||
layer.w2_weight_scale,
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type_id=scalar_types.float8_e4m3fn.id,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
else:
|
||||
return self.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
w1_scale=(layer.w13_weight_scale_inv
|
||||
if self.block_quant else layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale_inv
|
||||
if self.block_quant else layer.w2_weight_scale),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
)
|
||||
|
||||
|
||||
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
||||
"""
|
||||
Supports loading kv-cache scaling factors from FP8 checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: Fp8Config):
|
||||
super().__init__(quant_config)
|
||||
565
model_executor/layers/quantization/gguf.py
Normal file
565
model_executor/layers/quantization/gguf.py
Normal file
@@ -0,0 +1,565 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import gguf
|
||||
import torch
|
||||
from gguf import GGMLQuantizationType as WeightType
|
||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GGUFConfig(QuantizationConfig):
|
||||
"""Config class for GGUF."""
|
||||
|
||||
def __init__(self, ) -> None:
|
||||
super().__init__()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return ("GGUFConfig()")
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "gguf"
|
||||
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16, torch.float32]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 60
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return [] # no extra configs.
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
|
||||
return cls()
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return GGUFLinearMethod(self)
|
||||
elif isinstance(layer, VocabParallelEmbedding):
|
||||
return GGUFEmbeddingMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return GGUFMoEMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
|
||||
STANDARD_QUANT_TYPES = {
|
||||
WeightType.Q4_0,
|
||||
WeightType.Q4_1,
|
||||
WeightType.Q5_0,
|
||||
WeightType.Q5_1,
|
||||
WeightType.Q8_0,
|
||||
WeightType.Q8_1,
|
||||
}
|
||||
KQUANT_TYPES = {
|
||||
WeightType.Q2_K,
|
||||
WeightType.Q3_K,
|
||||
WeightType.Q4_K,
|
||||
WeightType.Q5_K,
|
||||
WeightType.Q6_K,
|
||||
}
|
||||
IMATRIX_QUANT_TYPES = {
|
||||
WeightType.IQ1_M,
|
||||
WeightType.IQ1_S,
|
||||
WeightType.IQ2_XXS,
|
||||
WeightType.IQ2_XS,
|
||||
WeightType.IQ2_S,
|
||||
WeightType.IQ3_XXS,
|
||||
WeightType.IQ3_S,
|
||||
WeightType.IQ4_XS,
|
||||
WeightType.IQ4_NL,
|
||||
}
|
||||
# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
|
||||
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
|
||||
# MMQ kernel for I-Matrix quantization.
|
||||
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
|
||||
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
|
||||
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
|
||||
|
||||
|
||||
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
|
||||
qweight_type: int) -> torch.Tensor:
|
||||
# HACK: when doing chunked prefill we don't generate output tokens
|
||||
# so input to logits generator is empty which causes invalid parameter
|
||||
if x.shape[0] == 0:
|
||||
return torch.empty(x.shape[0],
|
||||
qweight.shape[0],
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
# there is no need to call any kernel for fp16/bf16
|
||||
if qweight_type in UNQUANTIZED_TYPES:
|
||||
return x @ qweight.T
|
||||
# enable MMVQ in contiguous batching with batch_size=1
|
||||
if x.shape[0] == 1 and qweight_type in MMVQ_QUANT_TYPES:
|
||||
y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
|
||||
# Use MMQ Kernel if it's available (standard + k-quants)
|
||||
elif qweight_type in MMQ_QUANT_TYPES:
|
||||
y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
|
||||
# If there is no available MMQ kernel, fallback to dequantize
|
||||
elif qweight_type in DEQUANT_TYPES:
|
||||
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
|
||||
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
|
||||
weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
|
||||
y = x @ weight.T
|
||||
else:
|
||||
# Raise an error if the quantization type is not supported.
|
||||
# Might be useful if llama.cpp adds a new quantization type.
|
||||
# Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
|
||||
qweight_type = WeightType(qweight_type)
|
||||
raise NotImplementedError(
|
||||
f"Unsupported GGUF quantization type: {qweight_type}")
|
||||
return y
|
||||
|
||||
|
||||
def _fused_mul_mat_gguf_fake(
|
||||
x: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
qweight_type: int,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(x.shape[0],
|
||||
qweight.shape[0],
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="_fused_mul_mat_gguf",
|
||||
op_func=_fused_mul_mat_gguf,
|
||||
mutates_args=[],
|
||||
fake_impl=_fused_mul_mat_gguf_fake,
|
||||
)
|
||||
fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
|
||||
def _fused_moe_gguf(
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
qweight_type: int,
|
||||
qweight_type2: int,
|
||||
activation: str,
|
||||
) -> torch.Tensor:
|
||||
|
||||
def act(x: torch.Tensor):
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
if activation == "silu":
|
||||
torch.ops._C.silu_and_mul(out, x)
|
||||
elif activation == "gelu":
|
||||
torch.ops._C.gelu_and_mul(out, x)
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation: {activation}")
|
||||
return out
|
||||
|
||||
# lazy import to avoid triggering triton import in CPU backend
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
moe_align_block_size)
|
||||
|
||||
out_hidden_states = torch.empty_like(x)
|
||||
# unless we decent expert reuse we are better off running moe_vec kernel
|
||||
if (qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES
|
||||
and x.shape[0] > 64):
|
||||
num_tokens, _ = x.shape
|
||||
E, N, _ = w1.shape
|
||||
top_k = topk_ids.shape[1]
|
||||
BLOCK_SIZE = ops.ggml_moe_get_block_size(qweight_type)
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = \
|
||||
moe_align_block_size(topk_ids, BLOCK_SIZE, E)
|
||||
out = ops.ggml_moe_a8(x, w1, sorted_token_ids, expert_ids,
|
||||
num_tokens_post_padded, qweight_type, N, top_k,
|
||||
num_tokens)
|
||||
out = act(out)
|
||||
out = ops.ggml_moe_a8(out, w2, sorted_token_ids, expert_ids,
|
||||
num_tokens_post_padded, qweight_type2,
|
||||
w2.shape[1], 1, num_tokens * top_k)
|
||||
out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
|
||||
topk_weights.view(num_tokens, top_k, 1))
|
||||
ops.moe_sum(out, out_hidden_states)
|
||||
elif qweight_type2 in MMVQ_QUANT_TYPES and qweight_type in MMVQ_QUANT_TYPES:
|
||||
num_tokens, _ = x.shape
|
||||
E, N, _ = w1.shape
|
||||
top_k = topk_ids.shape[1]
|
||||
|
||||
out = ops.ggml_moe_a8_vec(x, w1, topk_ids, top_k, qweight_type, N,
|
||||
num_tokens)
|
||||
out = act(out)
|
||||
|
||||
out = ops.ggml_moe_a8_vec(out, w2, topk_ids, 1, qweight_type2,
|
||||
w2.shape[1], num_tokens * top_k)
|
||||
out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
|
||||
topk_weights.view(num_tokens, top_k, 1))
|
||||
ops.moe_sum(out, out_hidden_states)
|
||||
else:
|
||||
logger.warning_once("There is no support for fast MoE kernel "
|
||||
"for current quantization method. "
|
||||
"Falling back to slow implementation. ")
|
||||
for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
|
||||
inp = x[tok].reshape((1, ) + x.shape[1:])
|
||||
current_hidden_state = None
|
||||
for ww, ii in zip(w, idx):
|
||||
expert_up = w1[ii]
|
||||
|
||||
out = fused_mul_mat_gguf(inp, expert_up, qweight_type)
|
||||
out = act(out)
|
||||
|
||||
expert_down = w2[ii]
|
||||
current_state = fused_mul_mat_gguf(out, expert_down,
|
||||
qweight_type2).mul_(ww)
|
||||
if current_hidden_state is None:
|
||||
current_hidden_state = current_state
|
||||
else:
|
||||
current_hidden_state.add_(current_state)
|
||||
out_hidden_states[tok] = current_hidden_state
|
||||
return out_hidden_states
|
||||
|
||||
|
||||
def _fused_moe_gguf_fake(
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
qweight_type: int,
|
||||
qweight_type2: int,
|
||||
activation: str,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="_fused_moe_gguf",
|
||||
op_func=_fused_moe_gguf,
|
||||
mutates_args=[],
|
||||
fake_impl=_fused_moe_gguf_fake,
|
||||
)
|
||||
fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
|
||||
def _apply_gguf_embedding(
|
||||
x: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
qweight_type: int,
|
||||
hidden_size: int,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> torch.Tensor:
|
||||
if qweight_type in UNQUANTIZED_TYPES:
|
||||
return torch.embedding(qweight, x)
|
||||
elif qweight_type in DEQUANT_TYPES:
|
||||
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
|
||||
x_flat = x.flatten()
|
||||
assert (hidden_size == qweight.shape[1] // type_size * block_size)
|
||||
quant = torch.index_select(qweight, dim=0, index=x_flat)
|
||||
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
|
||||
x_flat.shape[0], dtype)
|
||||
return dequant.view(*x.shape, hidden_size)
|
||||
else:
|
||||
qweight_type = WeightType(qweight_type)
|
||||
raise NotImplementedError(
|
||||
f"Unsupported GGUF quantization type: {qweight_type}")
|
||||
|
||||
|
||||
def _apply_gguf_embedding_fake(
|
||||
x: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
qweight_type: int,
|
||||
hidden_size: int,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device)
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="_apply_gguf_embedding",
|
||||
op_func=_apply_gguf_embedding,
|
||||
mutates_args=[],
|
||||
fake_impl=_apply_gguf_embedding_fake,
|
||||
)
|
||||
apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
|
||||
class GGUFLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GGUF.
|
||||
|
||||
Args:
|
||||
quant_config: The GGUF quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: GGUFConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
self.params_dtype = params_dtype
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
tensor_shape = (output_size_per_partition, input_size_per_partition)
|
||||
qweight = GGUFUninitializedParameter(requires_grad=False)
|
||||
set_weight_attrs(
|
||||
qweight, {
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
"tensor_shape": tensor_shape,
|
||||
"is_gguf_weight": True,
|
||||
"data_container": [],
|
||||
"shard_id": [],
|
||||
"shard_id_map": {},
|
||||
})
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
layer.register_parameter("qweight", qweight)
|
||||
|
||||
qweight_type = Parameter(torch.empty(len(output_partition_sizes),
|
||||
dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
set_weight_attrs(
|
||||
qweight_type, {
|
||||
"is_gguf_weight_type": True,
|
||||
"weight_type": 0,
|
||||
"shard_weight_type": {},
|
||||
"ignore_warning": True
|
||||
})
|
||||
set_weight_attrs(qweight_type, extra_weight_attrs)
|
||||
layer.register_parameter("qweight_type", qweight_type)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
qweight_type = layer.qweight_type.weight_type
|
||||
if not (qweight_type in UNQUANTIZED_TYPES
|
||||
or qweight_type in DEQUANT_TYPES):
|
||||
qweight_type = WeightType(qweight_type)
|
||||
raise ValueError(
|
||||
f"Unsupported GGUF quantization type {qweight_type} in "
|
||||
f"layer {layer}.")
|
||||
# For MergedColumnParallelLinear and QKVParallelLinear, we need to
|
||||
# materialize the padded weight parameter for CUDA Graph compatibility.
|
||||
self._create_padded_weight_param(layer)
|
||||
|
||||
def _create_padded_weight_param(self, layer: torch.nn.Module):
|
||||
"""Create padded weight parameter for GGUF MergedLinear layer."""
|
||||
qweight = layer.qweight
|
||||
shard_id_map = qweight.shard_id_map
|
||||
shard_id = qweight.shard_id
|
||||
if len(data_container := qweight.data_container) > 1:
|
||||
dtype = {data.dtype for data in data_container}
|
||||
assert len(dtype) == 1, ValueError(
|
||||
f"Data container has mixed dtypes: {dtype}")
|
||||
dtype = next(iter(dtype))
|
||||
# concat dim0 and pad dim1
|
||||
padded_side = max(x.size(1) for x in data_container)
|
||||
concat_side = sum(x.size(0) for x in data_container)
|
||||
# Pad the quantized weights to dense tensor, and create a map
|
||||
# with the location of each shard in the padded tensor.
|
||||
padded_data = torch.zeros((concat_side, padded_side),
|
||||
dtype=dtype,
|
||||
device=qweight.device)
|
||||
# (dim0_start, dim0_end, dim1_size)
|
||||
shard_offset_map = dict[str, tuple[int, int, int]]()
|
||||
for idx in shard_id:
|
||||
id_in_container = shard_id_map[idx]
|
||||
start = sum(
|
||||
x.size(0) for x in data_container[:id_in_container])
|
||||
end = start + data_container[id_in_container].size(0)
|
||||
size = data_container[id_in_container].size(1)
|
||||
padded_data[start:end, :size] = data_container[id_in_container]
|
||||
shard_offset_map[idx] = (start, end, size)
|
||||
qweight.data_container.clear()
|
||||
padded_param = Parameter(padded_data, requires_grad=False)
|
||||
set_weight_attrs(padded_param, vars(qweight))
|
||||
set_weight_attrs(padded_param,
|
||||
{"shard_offset_map": shard_offset_map})
|
||||
layer.register_parameter("qweight", padded_param)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
shard_id = layer.qweight.shard_id
|
||||
|
||||
if shard_id:
|
||||
# dequantize shard weights respectively
|
||||
shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
|
||||
qweight = layer.qweight
|
||||
result = []
|
||||
for idx in shard_id:
|
||||
start, end, offset = layer.qweight.shard_offset_map[idx]
|
||||
qweight_type = layer.qweight_type.shard_weight_type[idx]
|
||||
result.append(
|
||||
fused_mul_mat_gguf(
|
||||
x, qweight[start:end, :offset].contiguous(),
|
||||
qweight_type))
|
||||
out = torch.cat(result, axis=1)
|
||||
else:
|
||||
qweight = layer.qweight
|
||||
qweight_type = layer.qweight_type.weight_type
|
||||
out = fused_mul_mat_gguf(x, qweight, qweight_type)
|
||||
if bias is not None:
|
||||
out.add_(bias)
|
||||
return out
|
||||
|
||||
|
||||
class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
"""MoE method for GGUF.
|
||||
|
||||
Args:
|
||||
quant_config: The GGUF quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: GGUFConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
|
||||
tensor_shape = (num_experts, 2 * intermediate_size_per_partition,
|
||||
hidden_size)
|
||||
#gate up proj
|
||||
w13_qweight = GGUFUninitializedParameter(requires_grad=False)
|
||||
set_weight_attrs(
|
||||
w13_qweight, {
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
"tensor_shape": tensor_shape,
|
||||
"is_gguf_weight": True,
|
||||
"data_container": [],
|
||||
})
|
||||
set_weight_attrs(w13_qweight, extra_weight_attrs)
|
||||
layer.register_parameter("w13_qweight", w13_qweight)
|
||||
|
||||
w13_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
set_weight_attrs(w13_qweight_type, {
|
||||
"is_gguf_weight_type": True,
|
||||
"weight_type": 0,
|
||||
"ignore_warning": True
|
||||
})
|
||||
set_weight_attrs(w13_qweight_type, extra_weight_attrs)
|
||||
layer.register_parameter("w13_qweight_type", w13_qweight_type)
|
||||
|
||||
tensor_shape = (num_experts, intermediate_size_per_partition,
|
||||
hidden_size)
|
||||
#gate down proj
|
||||
w2_qweight = GGUFUninitializedParameter(requires_grad=False)
|
||||
set_weight_attrs(
|
||||
w2_qweight, {
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
"tensor_shape": tensor_shape,
|
||||
"is_gguf_weight": True,
|
||||
"data_container": [],
|
||||
})
|
||||
set_weight_attrs(w2_qweight, extra_weight_attrs)
|
||||
layer.register_parameter("w2_qweight", w2_qweight)
|
||||
|
||||
w2_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
set_weight_attrs(w2_qweight_type, {
|
||||
"is_gguf_weight_type": True,
|
||||
"weight_type": 0,
|
||||
"ignore_warning": True
|
||||
})
|
||||
|
||||
set_weight_attrs(w2_qweight_type, extra_weight_attrs)
|
||||
layer.register_parameter("w2_qweight_type", w2_qweight_type)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
):
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
if apply_router_weight_on_input:
|
||||
raise NotImplementedError(
|
||||
"Apply router weight on input is not supported for"
|
||||
"fused GGUF MoE method.")
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
|
||||
topk_weights, topk_ids,
|
||||
layer.w13_qweight_type.weight_type,
|
||||
layer.w2_qweight_type.weight_type, activation)
|
||||
|
||||
|
||||
class GGUFEmbeddingMethod(GGUFLinearMethod):
|
||||
"""Embedding method for GGUF.
|
||||
|
||||
Args:
|
||||
quant_config: The GGUF quantization config.
|
||||
"""
|
||||
|
||||
def embedding(self, layer: torch.nn.Module,
|
||||
x: torch.Tensor) -> torch.Tensor:
|
||||
qweight = layer.qweight
|
||||
qweight_type = layer.qweight_type.weight_type
|
||||
hidden_size = qweight.tensor_shape[1]
|
||||
|
||||
return apply_gguf_embedding(x,
|
||||
qweight,
|
||||
qweight_type,
|
||||
hidden_size,
|
||||
dtype=self.params_dtype)
|
||||
|
||||
|
||||
class GGUFUninitializedParameter(UninitializedParameter):
|
||||
cls_to_become = Parameter
|
||||
data_container: list[torch.Tensor]
|
||||
351
model_executor/layers/quantization/gptq.py
Normal file
351
model_executor/layers/quantization/gptq.py
Normal file
@@ -0,0 +1,351 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import enum
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
get_linear_quant_method)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter)
|
||||
|
||||
|
||||
class GPTQConfig(QuantizationConfig):
|
||||
"""Config class for GPTQ.
|
||||
|
||||
Reference: https://arxiv.org/abs/2210.17323
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
desc_act: bool,
|
||||
lm_head_quantized: bool,
|
||||
dynamic: dict[str, dict[str, Union[int, bool]]],
|
||||
) -> None:
|
||||
# GPTQModel use `dynamic` config property to allow per module
|
||||
# quantization config so each module can be individually optimized.
|
||||
# Format is dict[str, dict] where key is a regex string that can
|
||||
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
|
||||
# matching of a module.
|
||||
# Default to positive match, override base quant config mode, if no
|
||||
# prefix is used. Value is in dict format of field key and override
|
||||
# value.
|
||||
# Negative matching will skip quantization init for this module
|
||||
# entirely:
|
||||
# non-quantized inference. More details and quantization examples can be
|
||||
# found at: https://github.com/ModelCloud/GPTQModel
|
||||
# Example:
|
||||
# # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
|
||||
# # last 1/4 of the layers 16-21 has 8bit and group_size 64
|
||||
# dynamic = {
|
||||
# #`.*\.` matches the layers_node prefix
|
||||
# # positive match layer 10-15
|
||||
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
|
||||
# # positive match layer 16-21
|
||||
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
|
||||
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
|
||||
# }
|
||||
super().__init__()
|
||||
self.dynamic = dynamic
|
||||
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.desc_act = desc_act
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.pack_factor = Fraction(32, self.weight_bits)
|
||||
if self.weight_bits not in [2, 3, 4, 8]:
|
||||
raise ValueError(
|
||||
"Currently, only 2/3/4/8-bit weight quantization is "
|
||||
f"supported for GPTQ, but got {self.weight_bits} bits.")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"GPTQConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act}), "
|
||||
f"lm_head_quantized={self.lm_head_quantized}), "
|
||||
f"dynamic={self.dynamic}")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "gptq"
|
||||
|
||||
@classmethod
|
||||
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
def get_min_capability(cls) -> int:
|
||||
return 60
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "GPTQConfig":
|
||||
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
|
||||
dynamic = {} if dynamic is None else dynamic
|
||||
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
return cls(weight_bits, group_size, desc_act, lm_head_quantized,
|
||||
dynamic)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["GPTQLinearMethod"]:
|
||||
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
|
||||
|
||||
|
||||
class ExllamaState(Enum):
|
||||
|
||||
UNUSED = enum.auto()
|
||||
UNINITIALIZED = enum.auto()
|
||||
READY = enum.auto()
|
||||
|
||||
|
||||
class GPTQLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GPTQ.
|
||||
|
||||
Args:
|
||||
quant_config: The GPTQ quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: GPTQConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del output_size # Unused.
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
if input_size_per_partition % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if (output_size_per_partition % self.quant_config.pack_factor.numerator
|
||||
!= 0):
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
exllama_state = ExllamaState.UNINITIALIZED
|
||||
scale_and_zero_size = input_size // group_size
|
||||
scale_and_zero_input_dim = None
|
||||
if (input_size != input_size_per_partition
|
||||
and self.quant_config.group_size != -1):
|
||||
# For act-order models, we cannot use Exllama for row parallel layer
|
||||
if self.quant_config.desc_act:
|
||||
exllama_state = ExllamaState.UNUSED
|
||||
else:
|
||||
# we need to partition qzeros and scales for exllama kernel
|
||||
scale_and_zero_size = input_size_per_partition // group_size
|
||||
scale_and_zero_input_dim = 0
|
||||
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
output_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=0,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
g_idx = RowvLLMParameter(data=torch.tensor(
|
||||
[
|
||||
i // self.quant_config.group_size
|
||||
for i in range(input_size_per_partition)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
qzeros_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
weight_scale_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
if scale_and_zero_input_dim is None:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedColumnParameter(
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(output_dim=1,
|
||||
input_dim=0,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedvLLMParameter(
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("g_idx", g_idx)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
layer.exllama_state = exllama_state
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# for torch.compile
|
||||
layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
|
||||
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
|
||||
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
|
||||
layer.scales = Parameter(layer.scales.data, requires_grad=False)
|
||||
|
||||
# exllama needs to shuffle the weight after the weight is loaded
|
||||
# here we do the shuffle on first forward pass
|
||||
if self.quant_config.group_size == 128 or self.quant_config.group_size == 64:
|
||||
if self.quant_config.desc_act:
|
||||
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
|
||||
else:
|
||||
layer.g_idx.data = torch.empty((0, ),
|
||||
dtype=torch.int,
|
||||
device=layer.g_idx.device)
|
||||
layer.exllama_state = ExllamaState.READY
|
||||
ops.gptq_shuffle(layer.qweight, layer.g_idx,
|
||||
self.quant_config.weight_bits)
|
||||
|
||||
if layer.scales.dtype != torch.bfloat16:
|
||||
perm_space = torch.empty(0)
|
||||
temp_space = torch.empty(0)
|
||||
if self.quant_config.weight_bits == 4:
|
||||
# warmup
|
||||
reshaped_x = torch.randn(1, layer.qweight.shape[0]*8, dtype=layer.scales.dtype, device="cuda")
|
||||
_ = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
|
||||
layer.scales, layer.g_idx,
|
||||
layer.exllama_state == ExllamaState.READY,
|
||||
self.quant_config.weight_bits,
|
||||
self.quant_config.group_size,
|
||||
perm_space, temp_space,
|
||||
False)
|
||||
if self.quant_config.weight_bits == 8:
|
||||
# warmup
|
||||
reshaped_x = torch.randn(1, layer.qweight.shape[0]*4, dtype=layer.scales.dtype, device="cuda")
|
||||
_ = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
|
||||
layer.scales, layer.g_idx,
|
||||
layer.exllama_state == ExllamaState.READY,
|
||||
self.quant_config.weight_bits,
|
||||
self.quant_config.group_size,
|
||||
perm_space, temp_space,
|
||||
False)
|
||||
else:
|
||||
if layer.exllama_state == ExllamaState.UNINITIALIZED:
|
||||
if self.quant_config.desc_act:
|
||||
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
|
||||
else:
|
||||
layer.g_idx.data = torch.empty((0, ),
|
||||
dtype=torch.int,
|
||||
device=layer.g_idx.device)
|
||||
layer.exllama_state = ExllamaState.READY
|
||||
ops.gptq_shuffle(layer.qweight, layer.g_idx,
|
||||
self.quant_config.weight_bits)
|
||||
|
||||
"""
|
||||
perm_space = torch.empty(0)
|
||||
if self.quant_config.weight_bits == 4:
|
||||
# warmup
|
||||
reshaped_x = torch.randn(1, layer.qweight.shape[0]*8, dtype=layer.scales.dtype, device="cuda")
|
||||
_ = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
|
||||
layer.scales, layer.g_idx,
|
||||
layer.exllama_state == ExllamaState.READY,
|
||||
self.quant_config.weight_bits,
|
||||
self.quant_config.group_size,
|
||||
perm_space)
|
||||
if self.quant_config.weight_bits == 8:
|
||||
# warmup
|
||||
reshaped_x = torch.randn(1, layer.qweight.shape[0]*4, dtype=layer.scales.dtype, device="cuda")
|
||||
_ = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
|
||||
layer.scales, layer.g_idx,
|
||||
layer.exllama_state == ExllamaState.READY,
|
||||
self.quant_config.weight_bits,
|
||||
self.quant_config.group_size,
|
||||
perm_space)
|
||||
"""
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
out_shape = x.shape[:-1] + (layer.qweight.shape[-1], )
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
perm_space = torch.empty(0)
|
||||
temp_space = torch.empty(0)
|
||||
if self.quant_config.weight_bits == 4 or self.quant_config.weight_bits == 8:
|
||||
if self.quant_config.group_size == 128 or self.quant_config.group_size == 64:
|
||||
if self.quant_config.desc_act:
|
||||
perm_space = torch.empty(reshaped_x.shape[0], reshaped_x.shape[1],
|
||||
dtype=torch.float16, device="cuda")
|
||||
|
||||
if reshaped_x.dtype == torch.bfloat16:
|
||||
temp_space = torch.zeros(reshaped_x.shape[0], layer.qweight.shape[1],
|
||||
dtype=torch.float32, device="cuda")
|
||||
|
||||
output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
|
||||
layer.scales, layer.g_idx,
|
||||
layer.exllama_state == ExllamaState.READY,
|
||||
self.quant_config.weight_bits,
|
||||
self.quant_config.group_size,
|
||||
perm_space, temp_space,
|
||||
True if reshaped_x.dtype == torch.bfloat16 else False)
|
||||
if bias is not None:
|
||||
output.add_(bias)
|
||||
return output.reshape(out_shape)
|
||||
445
model_executor/layers/quantization/gptq_bitblas.py
Normal file
445
model_executor/layers/quantization/gptq_bitblas.py
Normal file
@@ -0,0 +1,445 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
BitBLASLinearKernel, MPLinearLayerConfig)
|
||||
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
|
||||
BITBLAS_SUPPORTED_NUM_BITS as GPTQ_BITBLAS_SUPPORTED_NUM_BITS)
|
||||
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
|
||||
BITBLAS_SUPPORTED_SYM as GPTQ_BITBLAS_SUPPORTED_SYM)
|
||||
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
|
||||
MINIMUM_BITBLAS_VERSION, bitblas_repeat_scales_on_all_ranks,
|
||||
check_bitblas_supported, verify_bitblas_supported)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GPTQBitBLASConfig(QuantizationConfig):
|
||||
"""Config class for GPTQ BitBLAS"""
|
||||
|
||||
# (num_bits, is_sym) -> quant_type
|
||||
TYPE_MAP = {
|
||||
(4, True): scalar_types.uint4b8,
|
||||
(8, True): scalar_types.uint8b128,
|
||||
}
|
||||
|
||||
TORCH_DTYPE = torch.float16
|
||||
GPTQ_CKPT_STORAGE_DTYPE = (
|
||||
"int32" # GPTQ Default Checkpoints use int32 as storage dtype
|
||||
)
|
||||
GPTQ_BITBLAS_STORAGE_DTYPE = "int8" # BitBLAS uses int8 as storage dtype
|
||||
TORCH_BITBLAS_STORAGE_DTYPE = getattr(torch, GPTQ_BITBLAS_STORAGE_DTYPE)
|
||||
# "original" or "rescale" or "quantized",
|
||||
# the gptq_bitblas prefer "quantized"
|
||||
ZEROS_MODE = "quantized"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
desc_act: bool,
|
||||
is_sym: bool,
|
||||
quant_method: Optional[str],
|
||||
lm_head_quantized: bool,
|
||||
) -> None:
|
||||
|
||||
try:
|
||||
import bitblas
|
||||
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
|
||||
raise ImportError(
|
||||
"bitblas version is wrong. Please "
|
||||
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
|
||||
except ImportError as e:
|
||||
bitblas_import_exception = e
|
||||
raise ValueError(
|
||||
"Trying to use the bitblas backend, but could not import"
|
||||
f"with the following error: {bitblas_import_exception}. "
|
||||
"Please install bitblas through the following command: "
|
||||
f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
|
||||
) from bitblas_import_exception
|
||||
|
||||
if desc_act and group_size == -1:
|
||||
# In this case, act_order == True is the same as act_order == False
|
||||
# (since we have only one group per output channel)
|
||||
desc_act = False
|
||||
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.desc_act = desc_act
|
||||
self.is_sym = is_sym
|
||||
self.quant_method = quant_method
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
|
||||
# Verify
|
||||
if self.weight_bits not in GPTQ_BITBLAS_SUPPORTED_NUM_BITS:
|
||||
raise ValueError(
|
||||
f"BitBLAS does not support weight_bits = {self.weight_bits}. "
|
||||
f"Only weight_bits = {GPTQ_BITBLAS_SUPPORTED_NUM_BITS} "
|
||||
"are supported.")
|
||||
|
||||
if self.is_sym not in GPTQ_BITBLAS_SUPPORTED_SYM:
|
||||
raise ValueError(
|
||||
f"BitBLAS does not support is_sym = {self.is_sym}. "
|
||||
f"Only sym = {GPTQ_BITBLAS_SUPPORTED_SYM} are supported.")
|
||||
|
||||
self.storage_dtype = self.GPTQ_BITBLAS_STORAGE_DTYPE
|
||||
|
||||
storage_nbit = int("".join(c for c in self.GPTQ_CKPT_STORAGE_DTYPE
|
||||
if c.isdigit()))
|
||||
|
||||
# 4 Bits packed into 32 bit datatype.
|
||||
self.pack_factor = storage_nbit // weight_bits
|
||||
self.nbits = weight_bits
|
||||
|
||||
# Zeros type for the quantized weights.
|
||||
self.zeros_mode = self.ZEROS_MODE
|
||||
|
||||
if (weight_bits, is_sym) not in self.TYPE_MAP:
|
||||
raise ValueError("Unsupported quantization config: "
|
||||
f"bits={weight_bits}, sym={is_sym}")
|
||||
|
||||
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"GPTQBitBLASConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act})"
|
||||
f"is_sym={self.is_sym}, "
|
||||
f"quant_method={self.quant_method})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "gptq_bitblas"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "GPTQBitBLASConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
is_sym = cls.get_from_keys(config, ["sym"])
|
||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
return cls(weight_bits, group_size, desc_act, is_sym, quant_method,
|
||||
lm_head_quantized)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg)
|
||||
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "bitblas"
|
||||
or user_quant == "gptq_bitblas")
|
||||
|
||||
if can_convert and is_valid_user_quant:
|
||||
msg = ("The model is convertible to {} during runtime."
|
||||
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
if can_convert and user_quant == "gptq":
|
||||
logger.info("Detected that the model can run with gptq_bitblas"
|
||||
", however you specified quantization=gptq explicitly,"
|
||||
" so forcing gptq. Use quantization=gptq_bitblas for"
|
||||
" faster inference")
|
||||
return None
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["GPTQBitBLASLinearMethod"]:
|
||||
if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
|
||||
and self.lm_head_quantized):
|
||||
return GPTQBitBLASLinearMethod(self)
|
||||
return None
|
||||
|
||||
@property
|
||||
def torch_storage_dtype(self) -> torch.dtype:
|
||||
return self.TORCH_BITBLAS_STORAGE_DTYPE
|
||||
|
||||
@classmethod
|
||||
def is_gptq_bitblas_compatible(cls, quant_config: dict[str, Any]):
|
||||
# Extract data from quant config.
|
||||
num_bits = quant_config.get("bits")
|
||||
group_size = quant_config.get("group_size")
|
||||
sym = quant_config.get("sym")
|
||||
desc_act = quant_config.get("desc_act")
|
||||
|
||||
# temporarily disable on ROCm platform
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
|
||||
# If we cannot find the info needed in the config, cannot convert.
|
||||
if (num_bits is None or group_size is None or sym is None
|
||||
or desc_act is None):
|
||||
return False
|
||||
|
||||
if (num_bits, sym) not in cls.TYPE_MAP:
|
||||
return False
|
||||
|
||||
# If the capability of the device is too low, cannot convert.
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
device_capability = major * 10 + minor
|
||||
if device_capability < cls.get_min_capability():
|
||||
return False
|
||||
|
||||
# Otherwise, can convert if model satisfies bitblas constraints.
|
||||
return check_bitblas_supported(quant_type=cls.TYPE_MAP[(num_bits,
|
||||
sym)],
|
||||
group_size=group_size)
|
||||
|
||||
|
||||
class GPTQBitBLASLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GPTQ BitBLAS.
|
||||
|
||||
Args:
|
||||
quant_config: The GPTQ BitBLAS quantization config.
|
||||
"""
|
||||
|
||||
kernel_type = BitBLASLinearKernel
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self, quant_config: GPTQBitBLASConfig) -> None:
|
||||
self.quant_config = quant_config
|
||||
# Verify supported on platform.
|
||||
verify_bitblas_supported(quant_type=self.quant_config.quant_type,
|
||||
group_size=self.quant_config.group_size)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
"""Creates quantized weights for use in linear operations.
|
||||
|
||||
The function initializes and returns a dictionary containing
|
||||
quantized weights, scales, and zeros
|
||||
for performing quantized matrix multiplication operations.
|
||||
|
||||
Args:
|
||||
input_size_per_partition: The size of the input partition.
|
||||
output_partition_sizes: The size of the output partition.
|
||||
input_size: The total size of the input (unused).
|
||||
output_size: The total size of the output (unused).
|
||||
params_dtype:
|
||||
The data type of the parameters (expected to be torch.float16).
|
||||
|
||||
Returns:
|
||||
A dictionary containing the quantized weights ('qweight'),
|
||||
scales ('scales'), and zeros ('zeros').
|
||||
|
||||
Raises:
|
||||
ValueError: If `params_dtype` is not `torch.float16` or
|
||||
if the input size per partition is not divisible by the
|
||||
group size in `quant_config`.
|
||||
"""
|
||||
if params_dtype != torch.float16:
|
||||
raise ValueError("Parameter data type must be torch.float16, "
|
||||
f"but got {params_dtype}")
|
||||
|
||||
# Normalize group_size
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
|
||||
if input_size_per_partition % group_size != 0:
|
||||
raise ValueError(
|
||||
f"Input size per partition ({input_size_per_partition}) must "
|
||||
f"be divisible by group size ({self.quant_config.group_size})."
|
||||
)
|
||||
|
||||
kernel_type = self.kernel_type
|
||||
# Validate output_size_per_partition
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
is_row_parallel = input_size != input_size_per_partition
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
|
||||
mp_linear_kernel_config = MPLinearLayerConfig(
|
||||
full_weight_shape=(input_size, output_size),
|
||||
partition_weight_shape=\
|
||||
(input_size_per_partition, output_size_per_partition),
|
||||
weight_type=self.quant_config.quant_type,
|
||||
act_type=params_dtype,
|
||||
group_size=self.quant_config.group_size,
|
||||
zero_points=False,
|
||||
has_g_idx=self.quant_config.desc_act
|
||||
)
|
||||
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for GPTQBitBLASLinearMethod",
|
||||
kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
# Normalize group_size
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
|
||||
# Determine sharding
|
||||
if bitblas_repeat_scales_on_all_ranks(self.quant_config.desc_act,
|
||||
self.quant_config.group_size,
|
||||
is_row_parallel):
|
||||
# By setting scale_dim == None, weight_loader will
|
||||
# repeat the scales on each GPU in TP>1 case.
|
||||
scales_and_zp_input_dim = None
|
||||
scales_and_zp_size = input_size // group_size
|
||||
else:
|
||||
# By setting scale_dim == 0, weight_loader will
|
||||
# shard the scales in TP>1 case.
|
||||
scales_and_zp_input_dim = 0
|
||||
scales_and_zp_size = input_size_per_partition // group_size
|
||||
|
||||
# Init buffers
|
||||
# Quantized weights
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
output_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=0,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
# Activation order
|
||||
# Ignore warning from fused linear layers such as QKVParallelLinear.
|
||||
g_idx = RowvLLMParameter(data=torch.empty(
|
||||
input_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
# Scales
|
||||
scales = Parameter(
|
||||
torch.empty(
|
||||
scales_and_zp_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
scales,
|
||||
{
|
||||
**extra_weight_attrs,
|
||||
"input_dim": scales_and_zp_input_dim,
|
||||
"output_dim": 1,
|
||||
},
|
||||
)
|
||||
|
||||
# Quantized zero-points
|
||||
qzeros_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scales_and_zp_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
weight_scale_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scales_and_zp_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
|
||||
if scales_and_zp_input_dim is None:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedColumnParameter(
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(output_dim=1,
|
||||
input_dim=0,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedvLLMParameter(
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("g_idx", g_idx)
|
||||
layer.register_parameter("scales", scales)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
|
||||
self.kernel = kernel_type(
|
||||
mp_linear_kernel_config,
|
||||
w_q_param_name="qweight",
|
||||
w_s_param_name="scales",
|
||||
w_zp_param_name="qzeros",
|
||||
w_gidx_param_name="g_idx",
|
||||
bitblas_quant_config=self.quant_config,
|
||||
)
|
||||
|
||||
# Initialize or retrieve the BitBLAS matrix multiplication operator.
|
||||
self.kernel.configure_bitblas_matmul(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition,
|
||||
params_dtype=params_dtype,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
out = self.kernel.apply_gptq_bitblas_linear(layer, x)
|
||||
if bias is not None:
|
||||
out.add_(bias)
|
||||
return out
|
||||
648
model_executor/layers/quantization/gptq_marlin.py
Normal file
648
model_executor/layers/quantization/gptq_marlin.py
Normal file
@@ -0,0 +1,648 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
MPLinearLayerConfig, choose_mp_linear_kernel)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
get_linear_quant_method)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported, check_moe_marlin_supports_layer,
|
||||
marlin_make_workspace_new, marlin_moe_permute_scales,
|
||||
marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GPTQMarlinConfig(QuantizationConfig):
|
||||
"""Config class for GPTQ Marlin"""
|
||||
|
||||
# (num_bits, is_sym) -> quant_type
|
||||
TYPE_MAP = {
|
||||
(4, True): scalar_types.uint4b8,
|
||||
(8, True): scalar_types.uint8b128,
|
||||
}
|
||||
|
||||
def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
|
||||
is_sym: bool, lm_head_quantized: bool,
|
||||
dynamic: dict[str, dict[str, Union[int, bool]]],
|
||||
full_config: dict[str, Any]) -> None:
|
||||
super().__init__()
|
||||
if desc_act and group_size == -1:
|
||||
# In this case, act_order == True is the same as act_order == False
|
||||
# (since we have only one group per output channel)
|
||||
desc_act = False
|
||||
|
||||
# GPTQModel use `dynamic` config property to allow per module
|
||||
# quantization config so each module can be individually optimized.
|
||||
# Format is dict[str, dict] where key is a regex string that can
|
||||
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
|
||||
# matching of a module.
|
||||
# Default to positive match, override base quant config mode, if no
|
||||
# prefix is used. Value is in dict format of field key and override
|
||||
# value.
|
||||
# Negative matching will skip quantization init for this module
|
||||
# entirely:
|
||||
# non-quantized inference. More details and quantization examples can be
|
||||
# found at: https://github.com/ModelCloud/GPTQModel
|
||||
# Example:
|
||||
# # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
|
||||
# # last 1/4 of the layers 16-21 has 8bit and group_size 64
|
||||
# dynamic = {
|
||||
# #`.*\.` matches the layers_node prefix
|
||||
# # positive match layer 10-15
|
||||
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
|
||||
# # positive match layer 16-21
|
||||
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
|
||||
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
|
||||
# }
|
||||
self.dynamic = dynamic
|
||||
|
||||
self.weight_bits = weight_bits
|
||||
self.is_sym = is_sym
|
||||
|
||||
self.pack_factor = 32 // weight_bits # packed into int32
|
||||
self.group_size = group_size
|
||||
self.desc_act = desc_act
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.full_config = full_config
|
||||
|
||||
if (weight_bits, is_sym) not in self.TYPE_MAP:
|
||||
raise ValueError("Unsupported quantization config: "
|
||||
f"bits={weight_bits}, sym={is_sym}")
|
||||
|
||||
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act}, "
|
||||
f"lm_head_quantized={self.lm_head_quantized}), "
|
||||
f"dynamic={self.dynamic}")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "gptq_marlin"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "GPTQMarlinConfig":
|
||||
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
|
||||
dynamic = {} if dynamic is None else dynamic
|
||||
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
is_sym = cls.get_from_keys(config, ["sym"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
return cls(weight_bits, group_size, desc_act, is_sym,
|
||||
lm_head_quantized, dynamic, config)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
|
||||
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "marlin"
|
||||
or user_quant == "gptq_marlin")
|
||||
|
||||
if can_convert and is_valid_user_quant:
|
||||
msg = ("The model is convertible to {} during runtime."
|
||||
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
if can_convert and user_quant == "gptq":
|
||||
logger.info("Detected that the model can run with gptq_marlin"
|
||||
", however you specified quantization=gptq explicitly,"
|
||||
" so forcing gptq. Use quantization=gptq_marlin for"
|
||||
" faster inference")
|
||||
return None
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if isinstance(layer, FusedMoE):
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import (
|
||||
MoeWNA16Config)
|
||||
if not check_moe_marlin_supports_layer(layer, self.group_size):
|
||||
logger.warning_once(
|
||||
f"Layer '{prefix}' is not supported by GPTQMoeMarlin. "
|
||||
"Falling back to Moe WNA16 kernels.")
|
||||
return MoeWNA16Config.from_config(
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
return GPTQMarlinMoEMethod(self)
|
||||
return get_linear_quant_method(self, layer, prefix,
|
||||
GPTQMarlinLinearMethod)
|
||||
|
||||
@classmethod
|
||||
def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]):
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits")
|
||||
group_size = quant_config.get("group_size")
|
||||
sym = quant_config.get("sym")
|
||||
desc_act = quant_config.get("desc_act")
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
|
||||
if quant_method != "gptq":
|
||||
return False
|
||||
|
||||
# Marlin conversion is only valid if required properties are found
|
||||
if (num_bits is None or group_size is None or sym is None
|
||||
or desc_act is None):
|
||||
return False
|
||||
|
||||
if (num_bits, sym) not in cls.TYPE_MAP:
|
||||
return False
|
||||
|
||||
return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
|
||||
group_size=group_size)
|
||||
|
||||
|
||||
class GPTQMarlinLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GPTQ Marlin.
|
||||
|
||||
Args:
|
||||
quant_config: The GPTQ Marlin quantization config.
|
||||
"""
|
||||
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
|
||||
self.quant_config = quant_config
|
||||
|
||||
# Verify supported on platform.
|
||||
verify_marlin_supported(quant_type=self.quant_config.quant_type,
|
||||
group_size=self.quant_config.group_size)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
is_row_parallel = input_size != input_size_per_partition
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
|
||||
mp_linear_kernel_config = MPLinearLayerConfig(
|
||||
full_weight_shape=(input_size, output_size),
|
||||
partition_weight_shape=\
|
||||
(input_size_per_partition, output_size_per_partition),
|
||||
weight_type=self.quant_config.quant_type,
|
||||
act_type=params_dtype,
|
||||
group_size=self.quant_config.group_size,
|
||||
zero_points=False,
|
||||
has_g_idx=self.quant_config.desc_act
|
||||
)
|
||||
|
||||
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
|
||||
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for GPTQMarlinLinearMethod",
|
||||
kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
# Normalize group_size
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
|
||||
# Determine sharding
|
||||
if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
|
||||
self.quant_config.group_size,
|
||||
is_row_parallel):
|
||||
# By setting scale_dim == None, weight_loader will
|
||||
# repeat the scales on each GPU in TP>1 case.
|
||||
scales_and_zp_input_dim = None
|
||||
scales_and_zp_size = input_size // group_size
|
||||
else:
|
||||
# By setting scale_dim == 0, weight_loader will
|
||||
# shard the scales in TP>1 case.
|
||||
scales_and_zp_input_dim = 0
|
||||
scales_and_zp_size = input_size_per_partition // group_size
|
||||
|
||||
# Quantized weights
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
output_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=0,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
# Activation order
|
||||
g_idx = RowvLLMParameter(data=torch.empty(
|
||||
input_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
qzeros_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scales_and_zp_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
weight_scale_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scales_and_zp_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
|
||||
if scales_and_zp_input_dim is None:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedColumnParameter(
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(output_dim=1,
|
||||
input_dim=0,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedvLLMParameter(
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("g_idx", g_idx)
|
||||
layer.register_parameter("scales", scales)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
|
||||
self.kernel = kernel_type(mp_linear_kernel_config,
|
||||
w_q_param_name="qweight",
|
||||
w_s_param_name="scales",
|
||||
w_zp_param_name="qzeros",
|
||||
w_gidx_param_name="g_idx")
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
|
||||
|
||||
class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
"""MoE Marlin method with quantization."""
|
||||
|
||||
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
|
||||
self.quant_config = quant_config
|
||||
if self.quant_config.quant_type.size_bits == 4:
|
||||
self.quant_type = scalar_types.uint4b8
|
||||
elif self.quant_config.quant_type.size_bits == 8:
|
||||
self.quant_type = scalar_types.uint8b128
|
||||
else:
|
||||
raise ValueError(
|
||||
"GPTQMarlinMoEMethod only supports int4 and int8 now.")
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
intermediate_size_full = extra_weight_attrs.pop(
|
||||
"intermediate_size_full")
|
||||
|
||||
self.is_k_full = (not self.quant_config.desc_act) or (
|
||||
intermediate_size_per_partition == intermediate_size_full)
|
||||
|
||||
if self.quant_config.group_size != -1:
|
||||
scales_size13 = hidden_size // self.quant_config.group_size
|
||||
w2_scales_size = (intermediate_size_full
|
||||
if self.quant_config.desc_act else
|
||||
intermediate_size_per_partition)
|
||||
scales_size2 = (w2_scales_size // self.quant_config.group_size)
|
||||
strategy = FusedMoeWeightScaleSupported.GROUP.value
|
||||
else:
|
||||
scales_size13 = 1
|
||||
scales_size2 = 1
|
||||
strategy = FusedMoeWeightScaleSupported.CHANNEL.value
|
||||
|
||||
extra_weight_attrs.update({
|
||||
"quant_method": strategy,
|
||||
"is_transposed": True
|
||||
})
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_qweight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size // self.quant_config.pack_factor,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_qweight", w13_qweight)
|
||||
set_weight_attrs(w13_qweight, extra_weight_attrs)
|
||||
# down_proj (row parallel)
|
||||
w2_qweight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
intermediate_size_per_partition //
|
||||
self.quant_config.pack_factor,
|
||||
hidden_size,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_qweight", w2_qweight)
|
||||
set_weight_attrs(w2_qweight, extra_weight_attrs)
|
||||
# up_proj scales
|
||||
w13_scales = torch.nn.Parameter(
|
||||
torch.empty(num_experts,
|
||||
scales_size13,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_scales", w13_scales)
|
||||
set_weight_attrs(w13_scales, extra_weight_attrs)
|
||||
# down_proj scales
|
||||
w2_scales = torch.nn.Parameter(
|
||||
torch.empty(num_experts,
|
||||
scales_size2,
|
||||
hidden_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_scales", w2_scales)
|
||||
set_weight_attrs(w2_scales, extra_weight_attrs)
|
||||
# dont shard the w2 scales when running act order
|
||||
set_weight_attrs(w2_scales,
|
||||
{"load_full_w2": self.quant_config.desc_act})
|
||||
# up_proj scales
|
||||
w13_qzeros = torch.nn.Parameter(
|
||||
torch.empty(num_experts,
|
||||
scales_size13,
|
||||
2 * intermediate_size_per_partition //
|
||||
self.quant_config.pack_factor,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_qzeros", w13_qzeros)
|
||||
set_weight_attrs(w13_qzeros, extra_weight_attrs)
|
||||
# down_proj scales
|
||||
w2_qzeros = torch.nn.Parameter(
|
||||
torch.empty(num_experts,
|
||||
scales_size2,
|
||||
hidden_size // self.quant_config.pack_factor,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_qzeros", w2_qzeros)
|
||||
set_weight_attrs(w2_qzeros, extra_weight_attrs)
|
||||
# dont shard the w2 scales when running act order
|
||||
set_weight_attrs(w2_qzeros,
|
||||
{"load_full_w2": self.quant_config.desc_act})
|
||||
w13_g_idx = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_g_idx", w13_g_idx)
|
||||
set_weight_attrs(w13_g_idx, extra_weight_attrs)
|
||||
w2_g_idx = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_g_idx", w2_g_idx)
|
||||
set_weight_attrs(w2_g_idx, extra_weight_attrs)
|
||||
w13_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_g_idx_sort_indices",
|
||||
w13_g_idx_sort_indices)
|
||||
set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
|
||||
w2_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_g_idx_sort_indices",
|
||||
w2_g_idx_sort_indices)
|
||||
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
|
||||
|
||||
device = layer.w13_qweight.device
|
||||
layer.workspace = marlin_make_workspace_new(device, 4)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
|
||||
# Process act_order
|
||||
if self.quant_config.desc_act:
|
||||
# Get sorting based on g_idx
|
||||
num_experts = layer.w13_g_idx.shape[0]
|
||||
w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
|
||||
w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
|
||||
w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
|
||||
w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
|
||||
for e in range(num_experts):
|
||||
w13_g_idx_sort_indices[e] = torch.argsort(
|
||||
layer.w13_g_idx[e]).to(torch.int32)
|
||||
w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
|
||||
torch.int32)
|
||||
w13_sorted_g_idx[e] = layer.w13_g_idx[e][
|
||||
w13_g_idx_sort_indices[e]]
|
||||
w2_sorted_g_idx[e] = layer.w2_g_idx[e][
|
||||
w2_g_idx_sort_indices[e]]
|
||||
replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
|
||||
replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
|
||||
replace_parameter(layer, "w13_g_idx_sort_indices",
|
||||
w13_g_idx_sort_indices)
|
||||
replace_parameter(layer, "w2_g_idx_sort_indices",
|
||||
w2_g_idx_sort_indices)
|
||||
else:
|
||||
# Reset g_idx related tensors
|
||||
num_experts = layer.w13_g_idx.shape[0]
|
||||
device = layer.w13_g_idx.device
|
||||
layer.w13_g_idx = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32,
|
||||
device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w2_g_idx = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32,
|
||||
device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32,
|
||||
device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32,
|
||||
device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
# Repack weights
|
||||
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
|
||||
layer.w13_qweight,
|
||||
layer.w13_g_idx_sort_indices,
|
||||
layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
layer.w13_qweight.shape[2],
|
||||
self.quant_config.quant_type.size_bits,
|
||||
)
|
||||
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
|
||||
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
|
||||
layer.w2_qweight,
|
||||
layer.w2_g_idx_sort_indices,
|
||||
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
layer.w2_qweight.shape[2],
|
||||
self.quant_config.quant_type.size_bits,
|
||||
)
|
||||
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
|
||||
# Repack scales
|
||||
marlin_w13_scales = marlin_moe_permute_scales(
|
||||
s=layer.w13_scales,
|
||||
size_k=layer.intermediate_size_per_partition,
|
||||
size_n=layer.w13_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
)
|
||||
replace_parameter(layer, "w13_scales", marlin_w13_scales)
|
||||
marlin_w2_scales = marlin_moe_permute_scales(
|
||||
s=layer.w2_scales,
|
||||
size_k=layer.w2_scales.shape[1] *
|
||||
(self.quant_config.group_size if self.quant_config.group_size != -1
|
||||
else self.quant_config.pack_factor),
|
||||
size_n=layer.w2_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
)
|
||||
replace_parameter(layer, "w2_scales", marlin_w2_scales)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
if apply_router_weight_on_input:
|
||||
raise NotImplementedError(
|
||||
"Apply router weight on input is not supported for "
|
||||
"fused Marlin MoE method.")
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
layer.w13_scales,
|
||||
layer.w2_scales,
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type_id=self.quant_type.id,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
g_idx1=layer.w13_g_idx,
|
||||
g_idx2=layer.w2_g_idx,
|
||||
sort_indices1=layer.w13_g_idx_sort_indices,
|
||||
sort_indices2=layer.w2_g_idx_sort_indices,
|
||||
workspace=layer.workspace,
|
||||
is_k_full=self.is_k_full)
|
||||
297
model_executor/layers/quantization/gptq_marlin_24.py
Normal file
297
model_executor/layers/quantization/gptq_marlin_24.py
Normal file
@@ -0,0 +1,297 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
GPTQ_MARLIN_24_TILE = 16
|
||||
GPTQ_MARLIN_24_MIN_THREAD_N = 128
|
||||
GPTQ_MARLIN_24_MIN_THREAD_K = 128
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL = 64
|
||||
|
||||
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [
|
||||
scalar_types.uint4b8, scalar_types.uint8b128
|
||||
]
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
|
||||
|
||||
|
||||
class GPTQMarlin24Config(QuantizationConfig):
|
||||
"""Config class for Marlin24.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
quant_type = {
|
||||
4: scalar_types.uint4b8,
|
||||
8: scalar_types.uint8b128,
|
||||
}.get(weight_bits)
|
||||
|
||||
self.group_size = group_size
|
||||
|
||||
# Verify
|
||||
if quant_type is None or \
|
||||
quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES:
|
||||
raise ValueError(
|
||||
f"Marlin_24 does not support quant_type = {quant_type}. "
|
||||
f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} "
|
||||
"are supported.")
|
||||
if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
|
||||
raise ValueError(
|
||||
f"Marlin_24 does not support group_size = {self.group_size}. "
|
||||
f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
|
||||
"are supported.")
|
||||
|
||||
self.quant_type = quant_type
|
||||
|
||||
# 4 Bits packed into 32 bit datatype.
|
||||
self.pack_factor = 32 // self.quant_type.size_bits
|
||||
|
||||
# Tile size used by marlin kernels.
|
||||
self.tile_size = 16
|
||||
|
||||
# Min out_features dim
|
||||
self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N
|
||||
|
||||
# Min in_features dim
|
||||
self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K
|
||||
|
||||
# Max parallel problems to solve at once (improves large
|
||||
# batch performance)
|
||||
self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL
|
||||
|
||||
# Permutation length used by the marlin kernels.
|
||||
self.perm_len = 1024
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "Marlin24Config(quant_type={}, group_size={})".format(
|
||||
self.quant_type, self.group_size)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "gptq_marlin_24"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "GPTQMarlin24Config":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
return cls(weight_bits, group_size)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
is_marlin_24_format = (
|
||||
hf_quant_cfg.get("checkpoint_format") == "marlin_24")
|
||||
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
|
||||
or user_quant == "gptq_marlin_24")
|
||||
|
||||
if is_marlin_24_format and is_valid_user_quant:
|
||||
msg = ("The model is serialized in {} format. "
|
||||
"Using {} kernel.".format(cls.get_name(), cls.get_name()))
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
return None
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["GPTQMarlin24LinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return GPTQMarlin24LinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class GPTQMarlin24LinearMethod(LinearMethodBase):
|
||||
"""Linear method for Marlin24.
|
||||
|
||||
Args:
|
||||
quant_config: The Marlin24 quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: GPTQMarlin24Config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del output_size # Unused.
|
||||
weight_loader = extra_weight_attrs["weight_loader"]
|
||||
if params_dtype != torch.float16:
|
||||
raise ValueError(
|
||||
f"The params dtype must be float16, but got {params_dtype}")
|
||||
|
||||
# Validate output_size_per_partition
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if output_size_per_partition % self.quant_config.min_n_threads != 0:
|
||||
raise ValueError(
|
||||
f"Weight output_size_per_partition = "
|
||||
f"{output_size_per_partition} is not divisible by "
|
||||
f"min_n_threads = {self.quant_config.min_n_threads}.")
|
||||
if output_size_per_partition % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
f"Weight output_size_per_partition = "
|
||||
f"{output_size_per_partition} is not divisible by "
|
||||
f"pack_factor = {self.quant_config.pack_factor}.")
|
||||
|
||||
# Validate input_size_per_partition
|
||||
if input_size_per_partition % self.quant_config.min_k_threads != 0:
|
||||
raise ValueError(
|
||||
f"Weight input_size_per_partition = "
|
||||
f"{input_size_per_partition} is not divisible by "
|
||||
f"min_k_threads = {self.quant_config.min_k_threads}.")
|
||||
if (self.quant_config.group_size != -1 and
|
||||
input_size_per_partition % self.quant_config.group_size != 0):
|
||||
raise ValueError(f"Weight input_size_per_partition = "
|
||||
f"{input_size_per_partition} is not divisible by "
|
||||
f"group_size = {self.quant_config.group_size}.")
|
||||
|
||||
# Check that we have at least 4 tiles horizontally in the shard
|
||||
num_tiles_per_perm = self.quant_config.perm_len // (
|
||||
self.quant_config.tile_size**2)
|
||||
if output_size_per_partition % num_tiles_per_perm != 0:
|
||||
raise ValueError(
|
||||
"Each permutation group must reside on the same gpu")
|
||||
|
||||
# Quantized 4Bit weights packed into Int32.
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.tile_size // 2,
|
||||
output_size_per_partition * self.quant_config.tile_size //
|
||||
self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
marlin_tile_size=self.quant_config.tile_size,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
# Meta
|
||||
meta = PackedvLLMParameter(data=torch.empty(
|
||||
input_size_per_partition // 8 // 2 // 2,
|
||||
output_size_per_partition * 2,
|
||||
device="cuda",
|
||||
dtype=torch.int16,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=1,
|
||||
marlin_tile_size=2,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
# Determine if channelwise or not
|
||||
input_groups = (1 if self.quant_config.group_size == -1 else
|
||||
input_size_per_partition //
|
||||
self.quant_config.group_size)
|
||||
|
||||
weight_scale_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
input_groups,
|
||||
output_size_per_partition,
|
||||
device="cuda",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
if input_groups == 1:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1,
|
||||
**weight_scale_args)
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(output_dim=1,
|
||||
input_dim=0,
|
||||
**weight_scale_args)
|
||||
|
||||
# Allocate workspace (Used for internal locking mechanism)
|
||||
max_workspace_size = (
|
||||
output_size_per_partition //
|
||||
self.quant_config.min_n_threads) * self.quant_config.max_parallel
|
||||
|
||||
workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
|
||||
device="cuda",
|
||||
dtype=torch.int),
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("B_24", qweight)
|
||||
layer.register_parameter("B_meta", meta)
|
||||
layer.register_parameter("s", scales)
|
||||
layer.register_parameter("workspace", workspace)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# required by torch.compile
|
||||
layer.B_24 = Parameter(layer.B_24.data, requires_grad=False)
|
||||
layer.s = Parameter(layer.s.data, requires_grad=False)
|
||||
layer.B_meta = Parameter(layer.B_meta.data, requires_grad=False)
|
||||
layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
qweight = layer.B_24
|
||||
meta = layer.B_meta
|
||||
scales = layer.s
|
||||
workspace = layer.workspace
|
||||
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
size_m = x_2d.shape[0]
|
||||
size_k = x_2d.shape[1]
|
||||
size_n = scales.shape[1]
|
||||
|
||||
output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
|
||||
workspace,
|
||||
self.quant_config.quant_type,
|
||||
size_m, size_n, size_k)
|
||||
|
||||
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output
|
||||
332
model_executor/layers/quantization/hqq_marlin.py
Normal file
332
model_executor/layers/quantization/hqq_marlin.py
Normal file
@@ -0,0 +1,332 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
marlin_make_empty_g_idx, marlin_permute_scales)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
MarlinWorkspace)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import gptq_pack
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class HQQMarlinConfig(QuantizationConfig):
|
||||
"""Config class for HQQ Marlin"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
skip_modules: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert group_size == 64, ("The only supported HQQ group size is "
|
||||
"currently 64.")
|
||||
assert weight_bits == 4, ("The only supported HQQ quantization "
|
||||
"bitsize is currently 4.")
|
||||
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.pack_factor = 32 // weight_bits # packed into int32 in GPTQ format
|
||||
self.quant_type = scalar_types.uint4
|
||||
self.skip_modules = skip_modules
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"HQQMarlinConfig(quant_type={self.quant_type}, "
|
||||
f"group_size={self.group_size})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "hqq"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "HQQMarlinConfig":
|
||||
wq_params = (config["quant_config"]["weight_quant_params"])
|
||||
weight_bits = cls.get_from_keys(wq_params, ["nbits"])
|
||||
group_size = cls.get_from_keys(wq_params, ["group_size"])
|
||||
skip_modules = config["skip_modules"]
|
||||
return cls(weight_bits, group_size, skip_modules)
|
||||
|
||||
def is_layer_skipped(self, prefix: str) -> bool:
|
||||
# Split the prefix into its dot-separated components
|
||||
components = prefix.split('.')
|
||||
|
||||
# Check if any of the skip modules exactly matches any component
|
||||
return self.skip_modules is not None and any(
|
||||
module_name in components for module_name in self.skip_modules)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
if self.is_layer_skipped(prefix):
|
||||
return UnquantizedLinearMethod()
|
||||
return HQQMarlinMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
# Empty HQQ parameter, will be ignored during loading
|
||||
class HQQEmptyParameter(BasevLLMParameter):
|
||||
|
||||
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
|
||||
pass
|
||||
|
||||
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
|
||||
pass
|
||||
|
||||
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
def error_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
|
||||
raise ValueError("No loader provided for HQQ parameter!")
|
||||
|
||||
|
||||
# HQQ packing creates issues with sharding - therefore, prior to loading, we
|
||||
# repack to GPTQ. We also reshape the weights to their proper GPTQ shape.
|
||||
class HQQweightParameter(PackedvLLMParameter):
|
||||
|
||||
# unpack function from https://github.com/mobiusml/hqq
|
||||
def unpack_4bit_u8(self,
|
||||
W_q: torch.Tensor) -> torch.Tensor: # uint8/2 > uint8
|
||||
assert self.weight_bits == 4, "Unsupported quant bitsize (must be 4)"
|
||||
|
||||
dtype = torch.uint8
|
||||
step = W_q.shape[0]
|
||||
tmp = torch.empty([2 * step, W_q.shape[1]],
|
||||
dtype=dtype,
|
||||
device=W_q.device)
|
||||
tmp[:step] = (W_q & 0b11110000) >> 4
|
||||
tmp[step:] = W_q & 0b00001111
|
||||
return tmp
|
||||
|
||||
def __init__(self, packed_factor: int, packed_dim: int, weight_bits: int,
|
||||
**kwargs):
|
||||
super().__init__(packed_factor, packed_dim, None, **kwargs)
|
||||
self.weight_bits = weight_bits
|
||||
self.input_shape = self.shape[self.input_dim] * self.packed_factor
|
||||
self.output_shape = self.shape[self.output_dim]
|
||||
|
||||
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
|
||||
loaded_weight = self.unpack_4bit_u8(loaded_weight)
|
||||
loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose(
|
||||
1, 0)
|
||||
loaded_weight = gptq_pack(loaded_weight, self.weight_bits,
|
||||
loaded_weight.shape[0],
|
||||
loaded_weight.shape[1])
|
||||
super().load_merged_column_weight(loaded_weight, **kwargs)
|
||||
|
||||
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
|
||||
loaded_weight = self.unpack_4bit_u8(loaded_weight)
|
||||
loaded_weight = loaded_weight.reshape(self.output_shape,
|
||||
-1).transpose(1, 0)
|
||||
loaded_weight = gptq_pack(loaded_weight, self.weight_bits,
|
||||
loaded_weight.shape[0],
|
||||
loaded_weight.shape[1])
|
||||
super().load_row_parallel_weight(loaded_weight)
|
||||
|
||||
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
|
||||
loaded_weight = self.unpack_4bit_u8(loaded_weight)
|
||||
loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose(
|
||||
1, 0)
|
||||
loaded_weight = gptq_pack(loaded_weight, self.weight_bits,
|
||||
loaded_weight.shape[0],
|
||||
loaded_weight.shape[1])
|
||||
super().load_qkv_weight(loaded_weight, **kwargs)
|
||||
|
||||
|
||||
# Zero points and scales in HQQ must also be reshaped to correspond to W_q's
|
||||
# GPTQ shape (transposed - we transpose them too when processing weights).
|
||||
class HQQZeroScaleParameter(GroupQuantScaleParameter):
|
||||
|
||||
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
|
||||
loaded_weight = loaded_weight.reshape(-1, self.shape[1])
|
||||
super().load_merged_column_weight(loaded_weight, **kwargs)
|
||||
|
||||
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
|
||||
loaded_weight = loaded_weight.reshape(self.shape[0], -1)
|
||||
super().load_row_parallel_weight(loaded_weight)
|
||||
|
||||
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
|
||||
loaded_weight = loaded_weight.reshape(-1, self.shape[1])
|
||||
super().load_qkv_weight(loaded_weight, **kwargs)
|
||||
|
||||
|
||||
class HQQMarlinMethod(LinearMethodBase):
|
||||
"""Linear method for HQQ Marlin.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: HQQMarlinConfig,
|
||||
):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
self.output_size_per_partition = sum(output_partition_sizes)
|
||||
self.input_size_per_partition = input_size_per_partition
|
||||
|
||||
weight_loader = extra_weight_attrs.get("weight_loader", error_loader)
|
||||
|
||||
self.scales_and_zp_size = (input_size_per_partition //
|
||||
self.quant_config.group_size)
|
||||
|
||||
qweight = HQQweightParameter(
|
||||
data=torch.empty(
|
||||
self.input_size_per_partition // self.quant_config.pack_factor,
|
||||
self.output_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=0,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_bits=self.quant_config.weight_bits,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
zeros = HQQZeroScaleParameter(data=torch.empty(
|
||||
self.output_size_per_partition,
|
||||
self.scales_and_zp_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
scales = HQQZeroScaleParameter(data=torch.empty(
|
||||
self.output_size_per_partition,
|
||||
self.scales_and_zp_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("W_q", qweight)
|
||||
layer.register_parameter("zero", zeros)
|
||||
layer.register_parameter("scale", scales)
|
||||
|
||||
# Ignore extra parameters in the HQQ model.
|
||||
# To be added as needed.
|
||||
ignore_parameters = ("axis", "channel_wise", "compute_dtype",
|
||||
"encoded_state_dict", "group_size", "nbits",
|
||||
"offload_meta", "optimize", "packing",
|
||||
"quant_scale", "quant_zero", "round_zero",
|
||||
"shape", "stores_quant_config",
|
||||
"unpack_view_dtype", "view_as_float")
|
||||
for name in ignore_parameters:
|
||||
layer.register_parameter(
|
||||
name,
|
||||
HQQEmptyParameter(data=torch.empty(0),
|
||||
weight_loader=weight_loader))
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
dev = layer.W_q.device
|
||||
|
||||
# Repack to Marlin
|
||||
sort_indices = torch.empty(0, dtype=torch.int, device=dev)
|
||||
marlin_w_q = ops.gptq_marlin_repack(
|
||||
layer.W_q,
|
||||
sort_indices,
|
||||
self.input_size_per_partition,
|
||||
self.output_size_per_partition,
|
||||
self.quant_config.weight_bits,
|
||||
).to(dev)
|
||||
marlin_s = marlin_permute_scales(layer.scale.transpose(1, 0),
|
||||
self.input_size_per_partition,
|
||||
self.output_size_per_partition,
|
||||
self.quant_config.group_size).to(dev)
|
||||
marlin_zp = marlin_permute_scales(layer.zero.transpose(1, 0),
|
||||
self.input_size_per_partition,
|
||||
self.output_size_per_partition,
|
||||
self.quant_config.group_size).to(dev)
|
||||
|
||||
layer.g_idx = marlin_make_empty_g_idx(dev)
|
||||
layer.g_idx_sort_indices = marlin_make_empty_g_idx(dev)
|
||||
|
||||
layer.marlin_qweight = marlin_w_q
|
||||
layer.marlin_zeros = marlin_zp
|
||||
layer.marlin_scales = marlin_s
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
workspace = MarlinWorkspace(self.output_size_per_partition,
|
||||
GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
|
||||
scales = layer.marlin_scales
|
||||
zeros = layer.marlin_zeros
|
||||
orig_type = x.dtype
|
||||
|
||||
if orig_type != torch.float16:
|
||||
x = x.to(torch.float16)
|
||||
scales = scales.to(torch.float16)
|
||||
zeros = zeros.to(torch.float16)
|
||||
|
||||
marlin_out = ops.gptq_marlin_gemm(
|
||||
x,
|
||||
None,
|
||||
layer.marlin_qweight,
|
||||
scales,
|
||||
None,
|
||||
zeros,
|
||||
layer.g_idx,
|
||||
layer.g_idx_sort_indices,
|
||||
workspace.scratch,
|
||||
scalar_types.uint4,
|
||||
x.shape[0],
|
||||
self.output_size_per_partition,
|
||||
self.input_size_per_partition,
|
||||
True, # is_k_full
|
||||
False, # use atomic add
|
||||
True, # use 32-bit reduce
|
||||
True, # use float zp
|
||||
)
|
||||
|
||||
if orig_type != torch.float16:
|
||||
marlin_out = marlin_out.to(orig_type)
|
||||
|
||||
if bias is not None:
|
||||
marlin_out.add_(bias)
|
||||
|
||||
return marlin_out
|
||||
250
model_executor/layers/quantization/ipex_quant.py
Normal file
250
model_executor/layers/quantization/ipex_quant.py
Normal file
@@ -0,0 +1,250 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod,
|
||||
is_layer_skipped_awq)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MIN_IPEX_VERSION = "2.7.0"
|
||||
|
||||
|
||||
class IPEXConfig(QuantizationConfig):
|
||||
"""INT8 quantization config class using IPEX for the CPU/XPU backend,
|
||||
including AWQ, GPTQ.
|
||||
"""
|
||||
|
||||
IPEX_QUANT_METHOD_MAP = {
|
||||
"awq": 1,
|
||||
"gptq": 0,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
method: str,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
modules_to_not_convert: Optional[list[str]] = None,
|
||||
desc_act: Optional[bool] = None,
|
||||
lm_head_quantized: Optional[bool] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.method = method
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.modules_to_not_convert = modules_to_not_convert or []
|
||||
self.desc_act = desc_act
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.pack_factor = 32 // self.weight_bits
|
||||
|
||||
if self.weight_bits not in [4]:
|
||||
raise ValueError(f"IPEX quantization supports weight bits [4], "
|
||||
f"but got {self.weight_bits}.")
|
||||
|
||||
if self.method not in ["awq", "gptq"]:
|
||||
raise ValueError(f"IPEX quantization supports [awq, gptq], "
|
||||
f"but got {self.method}.")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"IPEXConfig(method={self.method},"
|
||||
f"weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "ipex"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.float16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return -1
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
return [
|
||||
"quant_config.json",
|
||||
"quantize_config.json",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "IPEXConfig":
|
||||
method = cls.get_from_keys(config, ["quant_method"]).lower()
|
||||
if method == "awq":
|
||||
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
|
||||
group_size = cls.get_from_keys(config,
|
||||
["q_group_size", "group_size"])
|
||||
modules_to_not_convert = cls.get_from_keys_or(
|
||||
config, ["modules_to_not_convert"], None)
|
||||
return cls(method, weight_bits, group_size, modules_to_not_convert,
|
||||
False, False)
|
||||
# otherwise for gptq
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False)
|
||||
return cls(method, weight_bits, group_size, [], desc_act,
|
||||
lm_head_quantized)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
if not current_platform.is_cpu() and not current_platform.is_xpu():
|
||||
return None
|
||||
|
||||
quant_method = hf_quant_cfg.get("quant_method", "").lower()
|
||||
|
||||
if quant_method in ["awq", "gptq"]:
|
||||
return cls.get_name()
|
||||
|
||||
return None
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["LinearMethodBase"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
if self.method == "awq":
|
||||
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
|
||||
return UnquantizedLinearMethod()
|
||||
return IPEXAWQLinearMethod(self)
|
||||
if self.method == "gptq":
|
||||
return IPEXGPTQLinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class IPEXGPTQLinearMethod(GPTQLinearMethod):
|
||||
"""GPTQ linear method using IPEX for the CPU/XPU backend.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: IPEXConfig):
|
||||
self.quant_config = quant_config # type: ignore
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
bias = layer.bias if not layer.skip_bias_add else None
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if ipex.__version__ < MIN_IPEX_VERSION:
|
||||
raise ImportError(
|
||||
"intel_extension_for_pytorch version is "
|
||||
"wrong. Please install "
|
||||
f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.")
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"Please install "
|
||||
f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via "
|
||||
f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`"
|
||||
" to use IPEX-AWQ linear method.") from err
|
||||
# Using the compute dtype (lowp_mode) as INT8 to leverage instructions
|
||||
# with better performance.
|
||||
lowp_mode = ipex.quantization.WoqLowpMode.INT8
|
||||
# The weight will be de-packed from INT4 to INT8.
|
||||
weight_dtype = ipex.quantization.WoqWeightDtype.INT4
|
||||
# The float activation will be quantized (dynamic, per-token) to INT8.
|
||||
act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK
|
||||
|
||||
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
|
||||
weight_dtype=weight_dtype,
|
||||
lowp_mode=lowp_mode,
|
||||
act_quant_mode=act_quant_mode,
|
||||
group_size=self.quant_config.group_size,
|
||||
)
|
||||
layer.ipex_output_size = layer.qweight.shape[-1]
|
||||
g_idx = layer.g_idx if self.quant_config.desc_act else None
|
||||
layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \
|
||||
IPEXWeightOnlyQuantizedLinear.from_weight(
|
||||
layer.qweight,
|
||||
layer.scales,
|
||||
layer.qzeros,
|
||||
layer.qweight.size(0),
|
||||
layer.ipex_output_size,
|
||||
qconfig=qconfig,
|
||||
g_idx=g_idx,
|
||||
bias=bias,
|
||||
group_size=self.quant_config.group_size,
|
||||
quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["gptq"]
|
||||
)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out = layer.ipex_qlinear(reshaped_x)
|
||||
return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))
|
||||
|
||||
|
||||
class IPEXAWQLinearMethod(AWQLinearMethod):
|
||||
"""AWQ linear method using IPEX for the CPU/XPU backend.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: IPEXConfig):
|
||||
self.quant_config = quant_config # type: ignore
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
super().process_weights_after_loading(layer=layer)
|
||||
|
||||
bias = layer.bias if not layer.skip_bias_add else None
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if ipex.__version__ < MIN_IPEX_VERSION:
|
||||
raise ImportError(
|
||||
"intel_extension_for_pytorch version is "
|
||||
"wrong. Please install "
|
||||
f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.")
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"Please install "
|
||||
f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via "
|
||||
f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`"
|
||||
" to use IPEX-AWQ linear method.") from err
|
||||
|
||||
# Using the compute dtype (lowp_mode) as INT8 to leverage instructions
|
||||
# with better performance.
|
||||
lowp_mode = ipex.quantization.WoqLowpMode.INT8
|
||||
# The weight will be de-packed from INT4 to INT8.
|
||||
weight_dtype = ipex.quantization.WoqWeightDtype.INT4
|
||||
# The float activation will be quantized (dynamic, per-token) to INT8.
|
||||
act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH
|
||||
|
||||
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
|
||||
weight_dtype=weight_dtype,
|
||||
lowp_mode=lowp_mode,
|
||||
act_quant_mode=act_quant_mode,
|
||||
group_size=self.quant_config.group_size,
|
||||
)
|
||||
|
||||
layer.ipex_output_size = layer.qweight.size(
|
||||
1) * self.quant_config.pack_factor
|
||||
layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \
|
||||
IPEXWeightOnlyQuantizedLinear.from_weight(
|
||||
layer.qweight,
|
||||
layer.scales,
|
||||
layer.qzeros,
|
||||
layer.qweight.size(0),
|
||||
layer.ipex_output_size,
|
||||
qconfig=qconfig,
|
||||
bias=bias,
|
||||
group_size=self.quant_config.group_size,
|
||||
quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["awq"] # type: ignore
|
||||
)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out = layer.ipex_qlinear(reshaped_x)
|
||||
return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))
|
||||
@@ -0,0 +1,90 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.scalar_type import ScalarType
|
||||
|
||||
|
||||
@dataclass
|
||||
class MPLinearLayerConfig:
|
||||
full_weight_shape: tuple[int, int] # [in, out]
|
||||
partition_weight_shape: tuple[int, int]
|
||||
weight_type: ScalarType
|
||||
act_type: torch.dtype
|
||||
group_size: int
|
||||
zero_points: bool
|
||||
has_g_idx: bool
|
||||
|
||||
|
||||
class MPLinearKernel(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(self,
|
||||
c: MPLinearLayerConfig,
|
||||
w_q_param_name: str,
|
||||
w_s_param_name: str,
|
||||
w_zp_param_name: Optional[str] = None,
|
||||
w_gidx_param_name: Optional[str] = None) -> None:
|
||||
assert self.can_implement(c)
|
||||
self.config = c
|
||||
self.w_q_name = w_q_param_name
|
||||
self.w_s_name = w_s_param_name
|
||||
if c.zero_points:
|
||||
assert w_zp_param_name is not None
|
||||
if c.has_g_idx:
|
||||
assert w_gidx_param_name is not None
|
||||
self.w_zp_name = w_zp_param_name
|
||||
self.w_gidx_name = w_gidx_param_name
|
||||
|
||||
@abstractmethod
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def _transform_param(self, layer: torch.nn.Module, name: Optional[str],
|
||||
fn: Callable) -> None:
|
||||
if name is not None and getattr(layer, name, None) is not None:
|
||||
|
||||
old_param = getattr(layer, name)
|
||||
new_param = fn(old_param)
|
||||
# replace the parameter with torch.nn.Parameter for TorchDynamo
|
||||
# compatibility
|
||||
replace_parameter(
|
||||
layer, name,
|
||||
torch.nn.Parameter(new_param.data, requires_grad=False))
|
||||
|
||||
def _get_weight_params(
|
||||
self, layer: torch.nn.Module) -> tuple[
|
||||
torch.Tensor, # w_q
|
||||
torch.Tensor, # w_s
|
||||
Optional[torch.Tensor], # w_zp,
|
||||
Optional[torch.Tensor] # w_gidx
|
||||
]:
|
||||
return (
|
||||
getattr(layer, self.w_q_name),
|
||||
getattr(layer, self.w_s_name),
|
||||
getattr(layer, self.w_zp_name or "", None),
|
||||
getattr(layer, self.w_gidx_name or "", None),
|
||||
)
|
||||
@@ -0,0 +1,83 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501
|
||||
AllSparkLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas import ( # noqa: E501
|
||||
BitBLASLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
|
||||
ExllamaLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
|
||||
MacheteLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( # noqa: E501
|
||||
MarlinLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # noqa: E501
|
||||
MPLinearKernel, MPLinearLayerConfig)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
|
||||
MacheteLinearKernel,
|
||||
AllSparkLinearKernel,
|
||||
MarlinLinearKernel,
|
||||
BitBLASLinearKernel,
|
||||
ExllamaLinearKernel,
|
||||
]
|
||||
|
||||
|
||||
def choose_mp_linear_kernel(
|
||||
config: MPLinearLayerConfig,
|
||||
compute_capability: Optional[int] = None) -> type[MPLinearKernel]:
|
||||
"""
|
||||
Choose an MPLinearKernel that can implement the given config for the given
|
||||
compute capability. Attempts to choose the best kernel in terms of
|
||||
performance.
|
||||
|
||||
Args:
|
||||
config (MPLinearLayerConfig): Description of the linear layer to be
|
||||
implemented.
|
||||
compute_capability (Optional[int], optional): The compute capability of
|
||||
the target device, if None uses `current_platform` to get the compute
|
||||
capability. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: If no kernel can implement the given config.
|
||||
|
||||
Returns:
|
||||
type[MPLinearKernel]: Chosen kernel.
|
||||
"""
|
||||
if compute_capability is None:
|
||||
if current_platform is None:
|
||||
raise ValueError("Cannot determine compute capability")
|
||||
_cc = current_platform.get_device_capability()
|
||||
compute_capability = _cc[0] * 10 + _cc[1]
|
||||
|
||||
failure_reasons = []
|
||||
for kernel in _POSSIBLE_KERNELS:
|
||||
if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
|
||||
failure_reasons.append(
|
||||
f' {kernel.__name__} disabled by environment variable')
|
||||
continue
|
||||
|
||||
if kernel.get_min_capability() > compute_capability:
|
||||
failure_reasons.append(
|
||||
f"{kernel.__name__} requires capability "
|
||||
f"{kernel.get_min_capability()}, current compute capability "
|
||||
f"is {compute_capability}")
|
||||
continue
|
||||
|
||||
can_implement, failure_reason = kernel.can_implement(config)
|
||||
if can_implement:
|
||||
return kernel
|
||||
else:
|
||||
failure_reasons.append(
|
||||
f' {kernel.__name__} cannot implement due to: {failure_reason}'
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
"Failed to find a kernel that can implement the "\
|
||||
"WNA16 linear layer. Reasons: \n"
|
||||
+ '\n'.join(failure_reasons))
|
||||
@@ -0,0 +1,116 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
|
||||
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, check_allspark_supported_dtype_shape)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
permute_param_layout_)
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
|
||||
class AllSparkLinearKernel(MPLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
if c.has_g_idx:
|
||||
return False, "Act reordering currently not supported by AllSpark"
|
||||
|
||||
if c.zero_points:
|
||||
return False, "Zero points currently not supported by AllSpark"
|
||||
|
||||
return check_allspark_supported_dtype_shape(
|
||||
c.partition_weight_shape[0], # in_features
|
||||
c.partition_weight_shape[1], # out_features
|
||||
c.group_size,
|
||||
c.weight_type,
|
||||
c.act_type)
|
||||
|
||||
# note assumes that
|
||||
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
|
||||
# `weight_scale` is: {input_dim = 0, output_dim = 1}
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
device = getattr(layer, self.w_q_name).device
|
||||
c = self.config
|
||||
|
||||
# prepare the parameters required for the kernel
|
||||
properties = torch.cuda.get_device_properties(device.index)
|
||||
sm_count = properties.multi_processor_count
|
||||
sm_version = properties.major * 10 + properties.minor
|
||||
gemm_args = {}
|
||||
gemm_args['sm_count'] = sm_count
|
||||
gemm_args['sm_version'] = sm_version
|
||||
|
||||
self.gemm_args = gemm_args
|
||||
|
||||
# transform param weight, scale
|
||||
old_weight_param = getattr(layer, self.w_q_name)
|
||||
old_scale_param = getattr(layer, self.w_s_name)
|
||||
|
||||
assert isinstance(old_weight_param, BasevLLMParameter)
|
||||
permute_param_layout_(old_weight_param,
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=0)
|
||||
|
||||
assert isinstance(old_scale_param, BasevLLMParameter)
|
||||
permute_param_layout_(old_scale_param, input_dim=0, output_dim=1)
|
||||
|
||||
# unpack weight from K / 4 x N int32 to K x N uint8
|
||||
new_weight_param = torch.nn.Parameter(old_weight_param.data,
|
||||
requires_grad=False)
|
||||
new_weight_param.data = new_weight_param.data.t().contiguous().view(
|
||||
dtype=torch.uint8)
|
||||
new_weight_param.data = new_weight_param.data.t().contiguous()
|
||||
|
||||
new_scale_param = torch.nn.Parameter(old_scale_param.data,
|
||||
requires_grad=False)
|
||||
|
||||
# reorder K x N weight as N32K16 format for Ampere W8A16
|
||||
new_weight_param.data, new_scale_param.data, _ = \
|
||||
ops.allspark_repack_weight(
|
||||
new_weight_param.data, new_scale_param.data, None,
|
||||
c.zero_points)
|
||||
|
||||
replace_parameter(layer, self.w_q_name, new_weight_param.data)
|
||||
replace_parameter(layer, self.w_s_name, new_scale_param.data)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
c = self.config
|
||||
gemm_args = self.gemm_args
|
||||
w_q, w_s, _, _ = self._get_weight_params(layer)
|
||||
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
|
||||
|
||||
output = ops.allspark_w8a16_gemm(
|
||||
a=reshaped_x,
|
||||
b_qweight=w_q,
|
||||
b_scales=w_s,
|
||||
b_qzeros=None,
|
||||
n=c.partition_weight_shape[1],
|
||||
group_size=c.group_size,
|
||||
sm_count=gemm_args['sm_count'],
|
||||
sm_version=gemm_args['sm_version'],
|
||||
CUBLAS_M_THRESHOLD=ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
|
||||
has_zp=c.zero_points,
|
||||
n32k16_reorder=True)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
@@ -0,0 +1,300 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
|
||||
BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_GROUP_SIZES,
|
||||
MINIMUM_BITBLAS_VERSION, bitblas_make_empty_g_idx, bitblas_sort_g_idx,
|
||||
check_bitblas_supports_shape, query_bitblas_supported_quant_types,
|
||||
unpack_gptq_qweight, unpack_gptq_qzeros)
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BitBLASLinearKernel(MPLinearKernel):
|
||||
|
||||
OPT_FEATURES: list[int] = BITBLAS_OPTIMIZE_FEATURES
|
||||
ENABLE_TUNING: bool = True
|
||||
MATMUL_LAYOUT: str = "nt"
|
||||
BITBLAS_DTYPES: dict[torch.dtype, str] = {
|
||||
torch.float32: "float32",
|
||||
torch.float16: "float16",
|
||||
torch.bfloat16: "bfloat16",
|
||||
torch.half: "float16",
|
||||
torch.int8: "int8",
|
||||
}
|
||||
bitblas_matmul: object = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
c: MPLinearLayerConfig,
|
||||
w_q_param_name: str,
|
||||
w_s_param_name: str,
|
||||
w_zp_param_name: Optional[str] = None,
|
||||
w_gidx_param_name: Optional[str] = None,
|
||||
bitblas_quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
self.quant_config = bitblas_quant_config
|
||||
super().__init__(c, w_q_param_name, w_s_param_name, w_zp_param_name,
|
||||
w_gidx_param_name)
|
||||
|
||||
def repack_bitblas_from_gptq(
|
||||
self,
|
||||
b_q_weight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
qzeros: Optional[torch.Tensor] = None,
|
||||
):
|
||||
from bitblas.quantization.utils import general_compress
|
||||
assert self.bitblas_matmul is not None, "bitblas_matmul is None"
|
||||
|
||||
quant_config = self.quant_config
|
||||
# qweight in gptq old quant linear stored with
|
||||
# (outfeatures, infeatures), should be transposed.
|
||||
qweight = b_q_weight.T.contiguous().view(
|
||||
quant_config.torch_storage_dtype) # type: ignore[union-attr]
|
||||
intweight = unpack_gptq_qweight(
|
||||
qweight,
|
||||
quant_config.weight_bits).contiguous() # type: ignore[union-attr]
|
||||
if self.bitblas_matmul.weight_transform is not None: # type: ignore[attr-defined]
|
||||
qweight = self.bitblas_matmul.weight_transform( # type: ignore[attr-defined]
|
||||
intweight.cpu()).cuda()
|
||||
# scales in gptq old quant linear stored with
|
||||
# (infeatures // group_size, outfeatures), should be transposed.
|
||||
scales = scales.T.contiguous()
|
||||
|
||||
if qzeros is None:
|
||||
return qweight, scales, None
|
||||
|
||||
# qzeros should be de-quantized to int zeros.
|
||||
weight_bits = quant_config.weight_bits # type: ignore[union-attr]
|
||||
intzeros = unpack_gptq_qzeros(qzeros, weight_bits).T.contiguous()
|
||||
zeros: Optional[torch.Tensor] = None
|
||||
zeros_mode = self.bitblas_matmul.config.zeros_mode # type: ignore[attr-defined]
|
||||
if zeros_mode == "original":
|
||||
zeros = intzeros.to(torch.float16).contiguous()
|
||||
elif zeros_mode == "rescale":
|
||||
assert zeros is not None, "zeros should not be None"
|
||||
zeros[:, :] = intzeros.to(torch.float16)[:, :] * scales[:, :]
|
||||
elif zeros_mode == "quantized":
|
||||
zeros = (
|
||||
torch.Tensor(
|
||||
general_compress(
|
||||
intzeros.T.contiguous().cpu().numpy(),
|
||||
weight_bits,
|
||||
)).to(qweight.device).
|
||||
to(quant_config.torch_storage_dtype # type: ignore[union-attr]
|
||||
).contiguous())
|
||||
else:
|
||||
raise ValueError("Unsupported zeros type: {}".format(zeros_mode))
|
||||
|
||||
return qweight, scales, zeros
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
|
||||
is_bitblas_installed = True
|
||||
|
||||
try:
|
||||
import bitblas
|
||||
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
|
||||
raise ImportError(
|
||||
"bitblas version is wrong. Please "
|
||||
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
|
||||
except ImportError:
|
||||
is_bitblas_installed = False
|
||||
|
||||
if not is_bitblas_installed:
|
||||
return False, "bitblas is not installed. Please install bitblas "\
|
||||
"by running `pip install bitblas>="\
|
||||
f"{MINIMUM_BITBLAS_VERSION}`"
|
||||
|
||||
quant_types = query_bitblas_supported_quant_types(c.zero_points)
|
||||
if c.weight_type not in quant_types:
|
||||
return False, (f"Quant type ({c.weight_type}) not supported by"
|
||||
f" BitBLAS, supported types are: {quant_types}")
|
||||
|
||||
if c.group_size not in BITBLAS_SUPPORTED_GROUP_SIZES:
|
||||
return False, (f"Group size ({c.group_size}) not supported by "
|
||||
"BitBLAS, supported group sizes are: "
|
||||
f"{BITBLAS_SUPPORTED_GROUP_SIZES}")
|
||||
|
||||
return check_bitblas_supports_shape(
|
||||
c.partition_weight_shape[1], # out_features
|
||||
c.partition_weight_shape[0], # in_features
|
||||
c.full_weight_shape[0], # in_features
|
||||
c.group_size)
|
||||
|
||||
# note assumes that
|
||||
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
|
||||
# `weight_scale` is: {input_dim = 0, output_dim = 1}
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
device = getattr(layer, self.w_q_name).device
|
||||
c = self.config
|
||||
quant_config = self.quant_config
|
||||
|
||||
# Default names since bitblas requires empty parameters for these,
|
||||
# TODO: remove this requirement from bitblas (allow optional tensors)
|
||||
if self.w_gidx_name is None:
|
||||
self.w_gidx_name = "g_idx"
|
||||
if self.w_zp_name is None:
|
||||
self.w_zp_name = "qzeros"
|
||||
|
||||
if c.has_g_idx:
|
||||
g_idx, g_idx_sort_indices = bitblas_sort_g_idx(
|
||||
getattr(layer, self.w_gidx_name))
|
||||
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
|
||||
layer.g_idx_sort_indices = g_idx_sort_indices
|
||||
else:
|
||||
setattr(layer, self.w_gidx_name, bitblas_make_empty_g_idx(device))
|
||||
layer.g_idx_sort_indices = bitblas_make_empty_g_idx(device)
|
||||
|
||||
if c.zero_points:
|
||||
raise NotImplementedError("Zero points not supported by BitBLAS")
|
||||
else:
|
||||
setattr(layer, self.w_zp_name, bitblas_make_empty_g_idx(device))
|
||||
|
||||
# Repack weights
|
||||
bitblas_qweight, bitblas_scales, bitblas_qzeros = (
|
||||
self.repack_bitblas_from_gptq(
|
||||
layer.qweight,
|
||||
layer.scales,
|
||||
None if quant_config.is_sym else # type: ignore[union-attr]
|
||||
layer.qzeros, # type: ignore[union-attr]
|
||||
))
|
||||
replace_parameter(layer, self.w_q_name, bitblas_qweight)
|
||||
replace_parameter(layer, self.w_s_name, bitblas_scales)
|
||||
if bitblas_qzeros is not None:
|
||||
replace_parameter(layer, self.w_zp_name, bitblas_qzeros)
|
||||
|
||||
def configure_bitblas_matmul(
|
||||
self,
|
||||
infeatures: int,
|
||||
outfeatures: int,
|
||||
params_dtype: torch.dtype,
|
||||
bias: bool,
|
||||
) -> None:
|
||||
enable_tuning = self.ENABLE_TUNING
|
||||
layout = self.MATMUL_LAYOUT
|
||||
bits = self.quant_config.weight_bits # type: ignore[union-attr]
|
||||
self._configure_bitblas_matmul(
|
||||
infeatures,
|
||||
outfeatures,
|
||||
params_dtype,
|
||||
enable_tuning,
|
||||
bias,
|
||||
layout,
|
||||
bits,
|
||||
)
|
||||
|
||||
def _configure_bitblas_matmul(
|
||||
self,
|
||||
infeatures,
|
||||
outfeatures,
|
||||
params_dtype,
|
||||
enable_tuning,
|
||||
bias,
|
||||
layout,
|
||||
bits,
|
||||
):
|
||||
from bitblas import MatmulConfig
|
||||
bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]
|
||||
quant_config = self.quant_config
|
||||
with_scaling = False
|
||||
with_zeros = False
|
||||
group_size = quant_config.group_size # type: ignore[union-attr]
|
||||
zeros_mode = quant_config.zeros_mode # type: ignore[union-attr]
|
||||
if quant_config.quant_method == "gptq": # type: ignore[union-attr]
|
||||
with_scaling = True
|
||||
with_zeros = True
|
||||
W_dtype = f"uint{bits}"
|
||||
if quant_config.is_sym: # type: ignore[union-attr]
|
||||
with_zeros = False
|
||||
W_dtype = f"int{bits}"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported quant_method {quant_config.quant_method}" # type: ignore[union-attr]
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
matmul_config = MatmulConfig(
|
||||
M=self.OPT_FEATURES,
|
||||
N=outfeatures,
|
||||
K=infeatures,
|
||||
A_dtype=bitblas_dtype,
|
||||
W_dtype=W_dtype,
|
||||
out_dtype=bitblas_dtype,
|
||||
accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype,
|
||||
storage_dtype=quant_config. # type: ignore[union-attr]
|
||||
storage_dtype, # type: ignore[union-attr]
|
||||
with_scaling=with_scaling,
|
||||
with_zeros=with_zeros,
|
||||
group_size=group_size,
|
||||
with_bias=bias,
|
||||
layout=layout,
|
||||
zeros_mode=zeros_mode,
|
||||
)
|
||||
self.bitblas_matmul = self._get_or_create_bitblas_operator(
|
||||
matmul_config, enable_tuning)
|
||||
|
||||
def _get_or_create_bitblas_operator(self, config, enable_tuning):
|
||||
from bitblas import Matmul, auto_detect_nvidia_target
|
||||
from bitblas.cache import get_database_path, global_operator_cache
|
||||
BITBLAS_DATABASE_PATH = get_database_path()
|
||||
BITBLAS_TARGET = auto_detect_nvidia_target()
|
||||
|
||||
if global_operator_cache.size() == 0:
|
||||
global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
|
||||
BITBLAS_TARGET)
|
||||
|
||||
bitblas_matmul = global_operator_cache.get(config)
|
||||
if bitblas_matmul is None:
|
||||
bitblas_matmul = Matmul(config,
|
||||
target=BITBLAS_TARGET,
|
||||
enable_tuning=False)
|
||||
if enable_tuning:
|
||||
bitblas_matmul.hardware_aware_finetune(topk=20)
|
||||
global_operator_cache.add(config, bitblas_matmul)
|
||||
global_operator_cache.save_into_database(
|
||||
BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
|
||||
TUNING_MESSAGE = (
|
||||
f"BitBLAS Operator {config} tuned and saved to database.")
|
||||
logger.info(TUNING_MESSAGE)
|
||||
else:
|
||||
_message = f"BitBLAS Operator {config} created without tuning. "
|
||||
logger.info(_message)
|
||||
else:
|
||||
_message = f"BitBLAS Operator {config} retrieved from cache."
|
||||
logger.info(_message)
|
||||
return bitblas_matmul
|
||||
|
||||
def apply_gptq_bitblas_linear(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
output_size_per_partition = self.config.partition_weight_shape[1]
|
||||
out_shape = x.shape[:-1] + (output_size_per_partition, )
|
||||
args = [x, layer.qweight, layer.scales]
|
||||
if self.bitblas_matmul.config.with_zeros: # type: ignore[attr-defined]
|
||||
args.append(layer.qzeros)
|
||||
output = self.bitblas_matmul(*args) # type: ignore[operator]
|
||||
return output.view(out_shape)
|
||||
|
||||
def apply_weights(self, layer, x, bias=None):
|
||||
NOT_IMPLEMENT_MESSAGE = (
|
||||
f"{self.__class__.__name__}.apply_weights is not implemented. "
|
||||
"Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead")
|
||||
raise NotImplementedError(NOT_IMPLEMENT_MESSAGE)
|
||||
@@ -0,0 +1,143 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
pack_quantized_values_into_int32)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
permute_param_layout_)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
|
||||
class ExllamaLinearKernel(MPLinearKernel):
|
||||
SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
# In theory supports `scalar_types.uint2b2, scalar_types.uint3b4` too but
|
||||
# currently untested so not added to the list
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 60
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
if c.has_g_idx and\
|
||||
c.partition_weight_shape[0] != c.full_weight_shape[0]:
|
||||
return False, "Act reordering currently not supported by Exllama, "\
|
||||
"when the input features are partitioned across "\
|
||||
"devices"
|
||||
|
||||
if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0:
|
||||
return False, "Output features must be a multiple of the pack " \
|
||||
"factor (32 / num_bits) so that we can correctly " \
|
||||
"pack the zero points"
|
||||
|
||||
if c.act_type != torch.float16:
|
||||
return False, "Exllama only supports float16 activations"
|
||||
|
||||
if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
|
||||
return False, f"Quant type ({c.weight_type}) not supported by "\
|
||||
"Exllama, supported types are: "\
|
||||
f"{cls.SUPPORTED_QUANT_TYPES}"
|
||||
|
||||
if c.full_weight_shape[0] % c.group_size != 0:
|
||||
return False, f"Group size ({c.group_size}) does not evenly divide"\
|
||||
" the number of input features "\
|
||||
f"({c.full_weight_shape[0]})"
|
||||
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
c = self.config
|
||||
|
||||
# For Exllama, we need to set a zero-point tensor if there is not one
|
||||
if not c.zero_points:
|
||||
self.w_zp_name = "qzeros"
|
||||
device = getattr(layer, self.w_q_name).device
|
||||
groups = c.partition_weight_shape[0] // c.group_size
|
||||
out_features = c.partition_weight_shape[1]
|
||||
|
||||
if c.weight_type.has_bias():
|
||||
# if the type has a bias we have to create a zeros tensor that
|
||||
# contains the bias values repeated for each group (-1 due to
|
||||
# a bug in the original GPTQ checkpoint format leading to
|
||||
# exllama kernel adding 1 to the zero points during inference)
|
||||
# Documentation of the bug can be found here:
|
||||
# https://garden.danieldk.eu/GPTQ-Checkpoint-Format
|
||||
zeros = torch.full((groups, out_features),
|
||||
c.weight_type.bias - 1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"A 0 zero-point is not supported by Exllama due to "
|
||||
"a bug in the original GPTQ checkpoint format leading to "
|
||||
"exllama kernel adding 1 to the zero points during "
|
||||
"inference")
|
||||
zeros = pack_quantized_values_into_int32(zeros,
|
||||
c.weight_type,
|
||||
packed_dim=1)
|
||||
setattr(layer, self.w_zp_name,
|
||||
torch.nn.Parameter(zeros, requires_grad=False))
|
||||
|
||||
if c.has_g_idx:
|
||||
|
||||
def transform_w_g_idx(x):
|
||||
# Exllama wants the permutation array instead of the group
|
||||
# indices
|
||||
return torch.argsort(x).to(torch.int)
|
||||
|
||||
self._transform_param(layer, self.w_gidx_name, transform_w_g_idx)
|
||||
else:
|
||||
self.w_gidx_name = "g_idx"
|
||||
empty_g_idx = torch.nn.Parameter(torch.empty((0, ),
|
||||
dtype=torch.int,
|
||||
device=device),
|
||||
requires_grad=False)
|
||||
setattr(layer, self.w_gidx_name, empty_g_idx)
|
||||
|
||||
def transform_w_q(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
assert self.w_gidx_name is not None
|
||||
g_idx = getattr(layer, self.w_gidx_name)
|
||||
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
|
||||
x_cont = x.data.contiguous()
|
||||
ops.gptq_shuffle(x_cont, g_idx, c.weight_type.size_bits)
|
||||
return x_cont
|
||||
|
||||
def transform_w_s(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1)
|
||||
x.data = x.data.contiguous()
|
||||
return x.to(dtype=c.act_type)
|
||||
|
||||
# Repack weights and scales for Machete
|
||||
self._transform_param(layer, self.w_q_name, transform_w_q)
|
||||
self._transform_param(layer, self.w_s_name, transform_w_s)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
c = self.config
|
||||
|
||||
x_2d = x.reshape(-1, x.shape[-1])
|
||||
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
|
||||
|
||||
w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer)
|
||||
|
||||
assert w_zp is not None, "Zero points are required by Exllama"
|
||||
assert w_g_idx is not None, "Group index is required by Exllama"
|
||||
output = ops.gptq_gemm(x_2d, w_q, w_zp, w_s, w_g_idx, True,
|
||||
c.weight_type.size_bits)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias)
|
||||
return output.reshape(out_shape)
|
||||
@@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.machete_utils import (
|
||||
MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
|
||||
query_machete_supported_quant_types)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
permute_param_layout_)
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
|
||||
class MacheteLinearKernel(MPLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 90
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
|
||||
if c.has_g_idx and\
|
||||
c.partition_weight_shape[0] != c.full_weight_shape[0]:
|
||||
return False, "Act reordering currently not supported by Machete, "\
|
||||
"when the input features are partitioned across "\
|
||||
"devices"
|
||||
if c.zero_points:
|
||||
return False, "Zero points currently not supported by Machete"
|
||||
|
||||
if c.weight_type not in query_machete_supported_quant_types(
|
||||
c.zero_points):
|
||||
return False, f"Quant type ({c.weight_type}) not supported by "\
|
||||
"Machete, supported types are: "\
|
||||
f"{query_machete_supported_quant_types(c.zero_points)}"
|
||||
|
||||
if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES:
|
||||
return False, f"Group size ({c.group_size}) not supported by "\
|
||||
"Machete, supported group sizes are: "\
|
||||
f"{MACHETE_SUPPORTED_GROUP_SIZES}"
|
||||
|
||||
return check_machete_supports_shape(c.partition_weight_shape[0],
|
||||
c.partition_weight_shape[1])
|
||||
|
||||
# note assumes that
|
||||
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
|
||||
# `weight_scale` is: {input_dim = 0, output_dim = 1}
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
c = self.config
|
||||
|
||||
if c.has_g_idx:
|
||||
assert self.w_gidx_name is not None
|
||||
perm = torch.argsort(getattr(layer, self.w_gidx_name))\
|
||||
.to(torch.int)
|
||||
|
||||
self.act_perm = lambda x: x[:, perm]
|
||||
# use `ops.permute_cols` if possible
|
||||
if c.act_type in [torch.float16, torch.bfloat16] \
|
||||
and c.partition_weight_shape[0] % 8 == 0:
|
||||
self.act_perm = partial(ops.permute_cols, perm=perm)
|
||||
|
||||
def transform_w_q(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
|
||||
if c.has_g_idx:
|
||||
x_unpacked = unpack_quantized_values_into_int32(x.data,
|
||||
c.weight_type,
|
||||
packed_dim=0)
|
||||
x_perm = x_unpacked[perm, :]
|
||||
x.data = pack_quantized_values_into_int32(x_perm,
|
||||
c.weight_type,
|
||||
packed_dim=0)
|
||||
x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
|
||||
a_type=c.act_type,
|
||||
b_type=c.weight_type,
|
||||
group_scales_type=c.act_type)
|
||||
return x
|
||||
|
||||
def transform_w_s(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1)
|
||||
x.data = x.data.contiguous()
|
||||
return x
|
||||
|
||||
# Repack weights and scales for Machete
|
||||
self._transform_param(layer, self.w_q_name, transform_w_q)
|
||||
self._transform_param(layer, self.w_s_name, transform_w_s)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
c = self.config
|
||||
w_q, w_s, _, _ = self._get_weight_params(layer)
|
||||
|
||||
x_2d = x.reshape(-1, x.shape[-1])
|
||||
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
|
||||
|
||||
if c.has_g_idx:
|
||||
x_2d = self.act_perm(x_2d)
|
||||
|
||||
output = ops.machete_mm(a=x_2d,
|
||||
b_q=w_q,
|
||||
b_type=c.weight_type,
|
||||
b_group_zeros=None,
|
||||
b_group_scales=w_s,
|
||||
b_group_size=c.group_size)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
@@ -0,0 +1,131 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
|
||||
check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
|
||||
marlin_make_workspace_new, marlin_permute_scales, marlin_sort_g_idx,
|
||||
marlin_zero_points, query_marlin_supported_quant_types, unpack_cols)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
permute_param_layout_)
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
|
||||
class MarlinLinearKernel(MPLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
|
||||
quant_types = query_marlin_supported_quant_types(c.zero_points)
|
||||
if c.weight_type not in quant_types:
|
||||
return False, f"Quant type ({c.weight_type}) not supported by"\
|
||||
f" Marlin, supported types are: {quant_types}"
|
||||
|
||||
if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
|
||||
return False, f"Group size ({c.group_size}) not supported by "\
|
||||
"Marlin, supported group sizes are: "\
|
||||
f"{MARLIN_SUPPORTED_GROUP_SIZES}"
|
||||
|
||||
return check_marlin_supports_shape(
|
||||
c.partition_weight_shape[1], # out_features
|
||||
c.partition_weight_shape[0], # in_features
|
||||
c.full_weight_shape[0], # in_features
|
||||
c.group_size)
|
||||
|
||||
# note assumes that
|
||||
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
|
||||
# `weight_scale` is: {input_dim = 0, output_dim = 1}
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
device = getattr(layer, self.w_q_name).device
|
||||
c = self.config
|
||||
|
||||
row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0])
|
||||
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
|
||||
|
||||
# Allocate marlin workspace.
|
||||
self.workspace = marlin_make_workspace_new(device)
|
||||
|
||||
# Default names since marlin requires empty parameters for these,
|
||||
# TODO: remove this requirement from marlin (allow optional tensors)
|
||||
if self.w_gidx_name is None:
|
||||
self.w_gidx_name = "g_idx"
|
||||
if self.w_zp_name is None:
|
||||
self.w_zp_name = "w_zp"
|
||||
|
||||
def transform_w_q(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
|
||||
x.data = ops.gptq_marlin_repack(x.data.contiguous(),
|
||||
perm=layer.g_idx_sort_indices,
|
||||
size_k=c.partition_weight_shape[0],
|
||||
size_n=c.partition_weight_shape[1],
|
||||
num_bits=c.weight_type.size_bits)
|
||||
return x
|
||||
|
||||
def transform_w_s(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1)
|
||||
x.data = marlin_permute_scales(x.data.contiguous(),
|
||||
size_k=c.partition_weight_shape[0],
|
||||
size_n=c.partition_weight_shape[1],
|
||||
group_size=c.group_size)
|
||||
return x
|
||||
|
||||
if c.has_g_idx:
|
||||
g_idx, g_idx_sort_indices = marlin_sort_g_idx(
|
||||
getattr(layer, self.w_gidx_name))
|
||||
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
|
||||
layer.g_idx_sort_indices = g_idx_sort_indices
|
||||
else:
|
||||
setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
|
||||
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
|
||||
|
||||
if c.zero_points:
|
||||
grouped_k = (c.partition_weight_shape[0] //
|
||||
c.group_size if c.group_size != -1 else 1)
|
||||
self._transform_param(layer, self.w_zp_name, lambda x: \
|
||||
marlin_zero_points(
|
||||
unpack_cols(x.t(), c.weight_type.size_bits,
|
||||
grouped_k,
|
||||
c.partition_weight_shape[1]),
|
||||
size_k=grouped_k,
|
||||
size_n=c.partition_weight_shape[1],
|
||||
num_bits=c.weight_type.size_bits))
|
||||
else:
|
||||
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
|
||||
self._transform_param(layer, self.w_q_name, transform_w_q)
|
||||
self._transform_param(layer, self.w_s_name, transform_w_s)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
c = self.config
|
||||
w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)
|
||||
|
||||
# `process_weights_after_loading` will ensure w_zp and w_gidx are not
|
||||
# None for marlin
|
||||
return apply_gptq_marlin_linear(
|
||||
input=x,
|
||||
weight=w_q,
|
||||
weight_scale=w_s,
|
||||
weight_zp=w_zp, # type: ignore
|
||||
g_idx=w_gidx, # type: ignore
|
||||
g_idx_sort_indices=layer.g_idx_sort_indices,
|
||||
workspace=self.workspace,
|
||||
wtype=c.weight_type,
|
||||
input_size_per_partition=c.partition_weight_shape[0],
|
||||
output_size_per_partition=c.partition_weight_shape[1],
|
||||
is_k_full=self.is_k_full,
|
||||
bias=bias)
|
||||
@@ -0,0 +1,67 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScaledMMLinearLayerConfig:
|
||||
is_channelwise: bool
|
||||
is_static_input_scheme: bool
|
||||
input_symmetric: bool
|
||||
|
||||
|
||||
class ScaledMMLinearKernel(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def can_implement(
|
||||
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(self, c: ScaledMMLinearLayerConfig, w_q_param_name: str,
|
||||
w_s_param_name: str, i_s_param_name: str,
|
||||
i_zp_param_name: str, azp_adj_param_name: str) -> None:
|
||||
assert self.can_implement(c)
|
||||
self.config = c
|
||||
self.w_q_name = w_q_param_name
|
||||
self.w_s_name = w_s_param_name
|
||||
self.i_s_name = i_s_param_name
|
||||
self.i_zp_name = i_zp_param_name
|
||||
self.azp_adj_name = azp_adj_param_name
|
||||
|
||||
@abstractmethod
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_weight_params(
|
||||
self, layer: torch.nn.Module) -> tuple[
|
||||
torch.Tensor, # weight
|
||||
torch.Tensor, # weight_scale
|
||||
Optional[torch.Tensor], # input_scale,
|
||||
Optional[torch.Tensor], # input_zp
|
||||
Optional[torch.Tensor], # azp_adj
|
||||
]:
|
||||
return (
|
||||
getattr(layer, self.w_q_name),
|
||||
getattr(layer, self.w_s_name),
|
||||
getattr(layer, self.i_s_name),
|
||||
getattr(layer, self.i_zp_name),
|
||||
getattr(layer, self.azp_adj_name),
|
||||
)
|
||||
@@ -0,0 +1,87 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
|
||||
AiterScaledMMLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
|
||||
CutlassScaledMMLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearKernel, ScaledMMLinearLayerConfig)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
|
||||
TritonScaledMMLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
|
||||
XLAScaledMMLinearKernel)
|
||||
from vllm.platforms import PlatformEnum, current_platform
|
||||
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
|
||||
PlatformEnum.CPU: [CutlassScaledMMLinearKernel],
|
||||
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
|
||||
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
|
||||
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
|
||||
}
|
||||
|
||||
|
||||
def choose_scaled_mm_linear_kernel(
|
||||
config: ScaledMMLinearLayerConfig,
|
||||
compute_capability: Optional[int] = None
|
||||
) -> type[ScaledMMLinearKernel]:
|
||||
"""
|
||||
Choose an ScaledMMLinearKernel that can implement the given config for the
|
||||
given compute capability. Attempts to choose the best kernel in terms of
|
||||
performance.
|
||||
|
||||
Args:
|
||||
config (ScaledMMLinearLayerConfig): Description of the linear layer
|
||||
to be implemented.
|
||||
compute_capability (Optional[int], optional): The compute capability of
|
||||
the target device, if None uses `current_platform` to get the
|
||||
compute capability. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: If no kernel can implement the given config.
|
||||
|
||||
Returns:
|
||||
type[ScaledMMLinearKernel]: Chosen kernel.
|
||||
"""
|
||||
|
||||
if compute_capability is None:
|
||||
_cc = current_platform.get_device_capability()
|
||||
if _cc is not None:
|
||||
compute_capability = _cc[0] * 10 + _cc[1]
|
||||
|
||||
failure_reasons = []
|
||||
for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
|
||||
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\
|
||||
.split(","):
|
||||
failure_reasons.append(
|
||||
f' {kernel.__name__} disabled by environment variable')
|
||||
continue
|
||||
|
||||
# If the current platform uses compute_capability,
|
||||
# make sure the kernel supports the compute cability.
|
||||
if compute_capability is not None:
|
||||
kernel_min_capability = kernel.get_min_capability()
|
||||
if (kernel_min_capability is not None
|
||||
and kernel_min_capability > compute_capability):
|
||||
failure_reasons.append(
|
||||
f"{kernel.__name__} requires capability "
|
||||
f"{kernel_min_capability}, current compute capability "
|
||||
f"is {compute_capability}")
|
||||
continue
|
||||
|
||||
can_implement, failure_reason = kernel.can_implement(config)
|
||||
if can_implement:
|
||||
return kernel
|
||||
else:
|
||||
failure_reasons.append(
|
||||
f' {kernel.__name__} cannot implement due to: {failure_reason}'
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
"Failed to find a kernel that can implement the "\
|
||||
"ScaledMM linear layer. Reasons: \n"
|
||||
+ '\n'.join(failure_reasons))
|
||||
120
model_executor/layers/quantization/kernels/scaled_mm/aiter.py
Normal file
120
model_executor/layers/quantization/kernels/scaled_mm/aiter.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .cutlass import CutlassScaledMMLinearKernel
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
|
||||
|
||||
|
||||
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 90
|
||||
|
||||
@classmethod
|
||||
def can_implement(
|
||||
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
if not current_platform.is_rocm():
|
||||
return (
|
||||
False,
|
||||
"AiterScaledMMLinearKernel requires `aiter` which is not " +
|
||||
"currently supported on non-ROCm platform.")
|
||||
|
||||
try:
|
||||
import aiter # noqa: F401 # deliberately attempt to import aiter
|
||||
except Exception:
|
||||
return (
|
||||
False,
|
||||
"AiterScaledMMLinearKernel requires `aiter` which is not " +
|
||||
"installed on ROCm.")
|
||||
# Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
|
||||
if not (
|
||||
envs.VLLM_ROCM_USE_AITER_LINEAR \
|
||||
and envs.VLLM_ROCM_USE_AITER
|
||||
):
|
||||
return (False, "AiterScaledMMLinearKernel is disabled. " +
|
||||
"Enable by setting `VLLM_ROCM_USE_AITER=1` " +
|
||||
"and `VLLM_ROCM_USE_AITER_LINEAR=1`. " +
|
||||
"`VLLM_ROCM_USE_AITER_LINEAR` default is True.")
|
||||
|
||||
if not c.input_symmetric:
|
||||
return (False,
|
||||
"AiterScaledMMLinearKernel only supports symmetric " +
|
||||
"quantization.")
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
super().process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""
|
||||
`AiterScaledMMLinearKernel` implements a fused version of
|
||||
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
|
||||
where scale_a * a and scale_b * b are implemented using numpy-style
|
||||
broadcasting.
|
||||
Currently only support per-tensor-per-tensor GEMM
|
||||
and per-token-per-channel GEMM through AITER
|
||||
w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support
|
||||
ATIER block scaled GEMM and mix-precision GEMM.
|
||||
"""
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
|
||||
|
||||
# ops.scaled_int8_quant supports both dynamic and static quant:
|
||||
# * dynamic, i_s is None and x_s computed from x.
|
||||
# * static, i_s is scalar and x_s is i_s.
|
||||
symmetric = azp_adj is None
|
||||
assert symmetric, ("AiterScaledMMLinearKernel only supports"
|
||||
" symmetric quantization.")
|
||||
x_q, x_s, x_zp = ops.scaled_int8_quant(x,
|
||||
i_s,
|
||||
i_zp,
|
||||
symmetric=symmetric)
|
||||
|
||||
assert x_zp is None, ("AiterScaledMMLinearKernel only supports"
|
||||
" symmetric quantization.")
|
||||
out_dtype = x.dtype
|
||||
|
||||
assert (w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0)
|
||||
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
|
||||
assert bias is None or bias.shape[0] == w_q.shape[
|
||||
1] and bias.dtype == out_dtype
|
||||
|
||||
m = x_q.shape[0] # a
|
||||
n = w_q.shape[1] # b
|
||||
|
||||
per_tensor_scale_a = (x_s.numel() == 1)
|
||||
per_tensor_scale_b = (w_s.numel() == 1)
|
||||
per_token_scale_a = (x_s.numel() == m)
|
||||
per_channel_scale_b = (w_s.numel() == n)
|
||||
|
||||
# @TODO:
|
||||
# Maybe broadcast the per-tensor-scale into per-channel-scale
|
||||
# if one of the scale is a per-channel-scale.
|
||||
# For now, it only supports:
|
||||
# - per-tensor-per-tensor a8w8 scaled GEMM, and
|
||||
# - per-token-per-channel a8w8 scaled GEMM
|
||||
assert ((per_tensor_scale_a and per_tensor_scale_b)
|
||||
or (per_token_scale_a and per_channel_scale_b)), (
|
||||
"Currently only support per-tensor-per-tensor GEMM " +
|
||||
" and per-token-per-channel GEMM through AITER"
|
||||
" w8a8 scaled gemm. `AiterScaledMMLinearKernel` " +
|
||||
"does not support AITER block scaled GEMM.")
|
||||
|
||||
from aiter import gemm_a8w8_CK
|
||||
|
||||
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
|
||||
# a to be [M, K]
|
||||
# b to be [N, K]
|
||||
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
|
||||
return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(out_dtype)
|
||||
137
model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
Normal file
137
model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .ScaledMMLinearKernel import (ScaledMMLinearKernel,
|
||||
ScaledMMLinearLayerConfig)
|
||||
|
||||
|
||||
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def can_implement(
|
||||
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
|
||||
if (not current_platform.is_cuda() and not current_platform.is_cpu()):
|
||||
return False, "CutlassScaledMM requires running on CUDA or CPU."
|
||||
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# WEIGHT
|
||||
# Cutlass kernels need transposed weight.
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
replace_parameter(
|
||||
layer, self.w_q_name,
|
||||
torch.nn.Parameter(weight.t().data, requires_grad=False))
|
||||
|
||||
# WEIGHT SCALE
|
||||
# Cutlass kernels support only per-tensor and per-channel.
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), convert to the per-channel case.
|
||||
is_fused_module = len(layer.logical_widths) > 1
|
||||
weight_scale = getattr(layer, self.w_s_name)
|
||||
if is_fused_module and not self.config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale,
|
||||
layer.logical_widths)
|
||||
replace_parameter(
|
||||
layer, self.w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False))
|
||||
|
||||
# INPUT SCALE
|
||||
if self.config.is_static_input_scheme:
|
||||
input_scale = getattr(layer, self.i_s_name)
|
||||
|
||||
if self.config.input_symmetric:
|
||||
replace_parameter(
|
||||
layer, self.i_s_name,
|
||||
torch.nn.Parameter(input_scale.max(), requires_grad=False))
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
else:
|
||||
input_zero_point = getattr(layer, self.i_zp_name)
|
||||
|
||||
# reconstruct the ranges
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
azps = input_zero_point.to(dtype=torch.int32)
|
||||
range_max = (input_scale * (int8_traits.max - azps)).max()
|
||||
range_min = (input_scale * (int8_traits.min - azps)).min()
|
||||
|
||||
scale = (range_max - range_min) / (int8_traits.max -
|
||||
int8_traits.min)
|
||||
replace_parameter(
|
||||
layer, self.i_s_name,
|
||||
torch.nn.Parameter(scale, requires_grad=False))
|
||||
|
||||
# AZP loaded as int8 but used as int32
|
||||
azp = (int8_traits.min -
|
||||
range_min / scale).to(dtype=torch.int32)
|
||||
replace_parameter(layer, self.i_zp_name,
|
||||
torch.nn.Parameter(azp, requires_grad=False))
|
||||
|
||||
else:
|
||||
setattr(layer, self.i_s_name, None)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
|
||||
# azp_adj is the AZP adjustment term, used to account for weights.
|
||||
# It does not depend on scales or azp, so it is the same for
|
||||
# static and dynamic quantization.
|
||||
# For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
|
||||
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
|
||||
if not self.config.input_symmetric:
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
|
||||
if self.config.is_static_input_scheme:
|
||||
# cutlass_w8a8 requires azp to be folded into azp_adj
|
||||
# in the per-tensor case
|
||||
azp_adj = getattr(layer, self.i_zp_name) * azp_adj
|
||||
setattr(layer, self.azp_adj_name,
|
||||
torch.nn.Parameter(azp_adj, requires_grad=False))
|
||||
else:
|
||||
setattr(layer, self.azp_adj_name, None)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
|
||||
|
||||
# ops.scaled_int8_quant supports both dynamic and static quant:
|
||||
# * dynamic, i_s is None and x_s computed from x.
|
||||
# * static, i_s is scalar and x_s is i_s.
|
||||
symmetric = azp_adj is None
|
||||
x_q, x_s, x_zp = ops.scaled_int8_quant(x.contiguous(),
|
||||
i_s,
|
||||
i_zp,
|
||||
symmetric=symmetric)
|
||||
|
||||
if x_zp is not None:
|
||||
# Currently, static is always per-tensor and dynamic is per-token
|
||||
static = i_zp is not None
|
||||
azp = None if static else x_zp
|
||||
return ops.cutlass_scaled_mm_azp(x_q,
|
||||
w_q,
|
||||
scale_a=x_s,
|
||||
scale_b=w_s,
|
||||
out_dtype=x.dtype,
|
||||
azp_adj=azp_adj,
|
||||
azp=azp,
|
||||
bias=bias)
|
||||
return ops.cutlass_scaled_mm(x_q,
|
||||
w_q,
|
||||
scale_a=x_s,
|
||||
scale_b=w_s,
|
||||
out_dtype=x.dtype,
|
||||
bias=bias)
|
||||
@@ -0,0 +1,41 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .cutlass import CutlassScaledMMLinearKernel
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
|
||||
|
||||
|
||||
class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def can_implement(
|
||||
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
if current_platform.is_cpu():
|
||||
return (
|
||||
False,
|
||||
"TritonScaledMMLinearKernel requires Triton which is not " +
|
||||
"currently supported on CPU.")
|
||||
if not c.input_symmetric:
|
||||
return (False,
|
||||
"TritonScaledMMLinearKernel only supports symmetric " +
|
||||
"quantization.")
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
super().process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return super().apply_weights(layer, x, bias)
|
||||
105
model_executor/layers/quantization/kernels/scaled_mm/xla.py
Normal file
105
model_executor/layers/quantization/kernels/scaled_mm/xla.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from functorch.experimental.control_flow import cond # noqa: F401
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .ScaledMMLinearKernel import (ScaledMMLinearKernel,
|
||||
ScaledMMLinearLayerConfig)
|
||||
|
||||
|
||||
class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
raise NotImplementedError(
|
||||
"TPU platform does have a concept of compute capability, "
|
||||
"this method should not be called.")
|
||||
|
||||
@classmethod
|
||||
def can_implement(
|
||||
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
|
||||
if not current_platform.is_tpu():
|
||||
return False, "ScaledMMXLA requires running on TPU."
|
||||
|
||||
if c.is_static_input_scheme:
|
||||
return False, "ScaledMMXLA requires dynamic activation scales."
|
||||
|
||||
if not c.input_symmetric:
|
||||
return False, "ScaledMMXLA requires symmetric activation scales."
|
||||
|
||||
if not c.is_channelwise:
|
||||
return False, "ScaledMMXLA requires channelwise weight scales"
|
||||
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# WEIGHT
|
||||
# [out, in] (different than cutlass_scaled_mm)
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
replace_parameter(layer, self.w_q_name,
|
||||
torch.nn.Parameter(weight.data, requires_grad=False))
|
||||
|
||||
# WEIGHT SCALE
|
||||
# XLA kernels support only per-tensor and per-channel.
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), convert to the per-channel case.
|
||||
is_fused_module = len(layer.logical_widths) > 1
|
||||
weight_scale = getattr(layer, self.w_s_name)
|
||||
if is_fused_module and not self.config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale,
|
||||
layer.logical_widths)
|
||||
|
||||
# [out_channel,] (different than cutlass_scaled_mm)
|
||||
weight_scale = weight_scale.squeeze(-1)
|
||||
replace_parameter(
|
||||
layer, self.w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False))
|
||||
|
||||
# Only support symmetric dynamic activation quantization.
|
||||
setattr(layer, self.i_s_name, None)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
setattr(layer, self.azp_adj_name, None)
|
||||
|
||||
# Filter warning for cond usage in apply_weights. It is okay
|
||||
# to specialize the graph since bias is not dynamic.
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=
|
||||
"Pred is a Python constant. When used with torch.cond, it specializes on one of the branches." # noqa: E501
|
||||
)
|
||||
|
||||
def no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]):
|
||||
return x
|
||||
|
||||
def add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]):
|
||||
return x + bias
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
w_q, w_s, _, _, _ = self._get_weight_params(layer)
|
||||
|
||||
import torch_xla.experimental.xla_quantized_matmul # noqa: F401
|
||||
out = torch.ops.xla.quantized_matmul(x,
|
||||
w_q,
|
||||
w_s,
|
||||
zero_point=None,
|
||||
block_size=-1,
|
||||
int4_weight=False,
|
||||
quantize_activation=True)
|
||||
# `quantized_matmul` output is fp32, cast it down to bf16 for perf
|
||||
out = out.to(x.dtype)
|
||||
# Explicitly capture control flow to make dynamo happy.
|
||||
# https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
|
||||
return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])
|
||||
139
model_executor/layers/quantization/kv_cache.py
Normal file
139
model_executor/layers/quantization/kv_cache.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BaseKVCacheMethod(QuantizeMethodBase):
|
||||
"""
|
||||
Quant method that adds `_k_scale` and `_v_scale` attributes to the
|
||||
Attention layer to support loading those scaling factors from checkpoints.
|
||||
The k/v_scale will be used to:
|
||||
- quantize k/v_cache entries before saving them to the cache
|
||||
- dequantize k/v_cache entries before fetching them from the cache
|
||||
|
||||
:param quant_config: the appropriate QuantizationConfig
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: QuantizationConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module):
|
||||
"""
|
||||
Create "weight" (aka q_scale, k_scale and v_scale)
|
||||
for an attention layer.
|
||||
"""
|
||||
# Initialize the Q and KV cache scales to -1.0, an invalid value.
|
||||
# If the q and k/v_scales appear in the checkpoint, it will be
|
||||
# overwritten when loading weights.
|
||||
layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0),
|
||||
requires_grad=False)
|
||||
layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0),
|
||||
requires_grad=False)
|
||||
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0),
|
||||
requires_grad=False)
|
||||
# Initialize P = softmax(QK^T) scales
|
||||
layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0),
|
||||
requires_grad=False)
|
||||
|
||||
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
|
||||
raise RuntimeError(
|
||||
f"{self.__class__.__name__}.apply should not be called.")
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
|
||||
# regardless whether the kv-scale is available in the checkpoint.
|
||||
# No need to process kv scales after loading if we are going to
|
||||
# calculate them on the fly.
|
||||
if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
|
||||
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
|
||||
# We prefer to use separate k_scale and v_scale if present
|
||||
k_scale = layer.k_scale.to("cpu").tolist()
|
||||
v_scale = layer.v_scale.to("cpu").tolist()
|
||||
if current_platform.is_fp8_fnuz():
|
||||
k_scale *= 2
|
||||
v_scale *= 2
|
||||
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
|
||||
# If no scales were loaded (both scales are invalid negative
|
||||
# values), use the default value of 1.0
|
||||
k_scale = 1.0
|
||||
v_scale = 1.0
|
||||
else:
|
||||
# If we find a single kv_scale in the checkpoint, we remap
|
||||
# kv_scale to k_scale during weight loading, and duplicate
|
||||
# k_scale to v_scale here
|
||||
assert layer.k_scale > 0.0
|
||||
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
|
||||
k_scale = scale_to_duplicate.to("cpu").tolist()
|
||||
v_scale = scale_to_duplicate.to("cpu").tolist()
|
||||
if current_platform.is_fp8_fnuz():
|
||||
k_scale *= 2
|
||||
v_scale *= 2
|
||||
|
||||
if not isinstance(k_scale, float) or not isinstance(
|
||||
v_scale, float):
|
||||
raise ValueError("Only support per-tensor scaling factor "
|
||||
"for fp8 KV cache")
|
||||
|
||||
if layer.q_scale < 0.0:
|
||||
logger.warning_once(
|
||||
"Checkpoint does not provide a q scaling factor. "
|
||||
"Setting it to k_scale. This only matters for "
|
||||
"the flash-attn backend.")
|
||||
layer._q_scale.copy_(k_scale)
|
||||
|
||||
# These are used in the final Attention.forward()
|
||||
layer._k_scale.copy_(k_scale)
|
||||
layer._v_scale.copy_(v_scale)
|
||||
layer._k_scale_float = k_scale
|
||||
layer._v_scale_float = v_scale
|
||||
if (k_scale == 1.0 and v_scale == 1.0
|
||||
and "e5m2" not in layer.kv_cache_dtype):
|
||||
logger.warning_once(
|
||||
"Using KV cache scaling factor 1.0 for fp8_e4m3. This "
|
||||
"may cause accuracy issues. Please make sure k/v_scale "
|
||||
"scaling factors are available in the fp8 checkpoint.")
|
||||
|
||||
if layer.q_scale > 0.0:
|
||||
q_scale = layer.q_scale
|
||||
if current_platform.is_fp8_fnuz():
|
||||
q_scale *= 2
|
||||
layer.calculate_kv_scales = False
|
||||
else:
|
||||
q_scale = 1.0
|
||||
if layer.prob_scale > 0.0:
|
||||
prob_scale = layer.prob_scale
|
||||
if current_platform.is_fp8_fnuz():
|
||||
prob_scale *= 2
|
||||
else:
|
||||
prob_scale = 1.0
|
||||
|
||||
is_singleton_float = lambda x: isinstance(x, float) or isinstance(
|
||||
x, torch.Tensor) and x.numel() == 1 and x.is_floating_point()
|
||||
if not is_singleton_float(q_scale) or not is_singleton_float(
|
||||
prob_scale):
|
||||
raise ValueError("Only support per-tensor scaling factor"
|
||||
"for fp8-quantized Q/prob")
|
||||
|
||||
# These are used in the final Attention.forward()
|
||||
layer._q_scale.copy_(q_scale)
|
||||
layer._prob_scale.copy_(prob_scale)
|
||||
if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0
|
||||
or prob_scale == 1.0):
|
||||
logger.warning_once(
|
||||
f"Using uncalibrated q_scale {q_scale} and/or prob_scale "
|
||||
f"{prob_scale} with fp8 attention. This may cause accuracy "
|
||||
"issues. Please make sure q/prob scaling factors are "
|
||||
"available in the fp8 checkpoint.")
|
||||
|
||||
del layer.k_scale
|
||||
del layer.v_scale
|
||||
del layer.q_scale
|
||||
del layer.prob_scale
|
||||
261
model_executor/layers/quantization/marlin.py
Normal file
261
model_executor/layers/quantization/marlin.py
Normal file
@@ -0,0 +1,261 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MarlinConfig(QuantizationConfig):
|
||||
"""Config class for Marlin.
|
||||
|
||||
Reference: https://github.com/IST-DASLab/marlin/tree/master
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group_size: int,
|
||||
lm_head_quantized: bool,
|
||||
) -> None:
|
||||
# Group size for the quantization.
|
||||
self.group_size = group_size
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
if self.group_size != 128 and self.group_size != -1:
|
||||
raise ValueError(
|
||||
"Currently, only group size 128 and -1 (channelwise) "
|
||||
"is supported for Marlin, but got group_size of "
|
||||
f"{self.group_size}")
|
||||
|
||||
# 4 Bits packed into 32 bit datatype.
|
||||
self.pack_factor = 32 // 4
|
||||
|
||||
# Tile size used by marlin kernels.
|
||||
self.tile_size = 16
|
||||
|
||||
# Min out_features dim
|
||||
self.min_n_threads = 64
|
||||
|
||||
# Min in_features dim
|
||||
self.min_k_threads = 128
|
||||
|
||||
# Max parallel problems to solve at once (improves large
|
||||
# batch performance)
|
||||
self.max_parallel = 16
|
||||
|
||||
# Permutation length used by the marlin kernels.
|
||||
self.perm_len = 1024
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"MarlinConfig(group_size={self.group_size}, "
|
||||
f"lm_head_quantized={self.lm_head_quantized})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "marlin"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "MarlinConfig":
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
return cls(group_size, lm_head_quantized)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
# compat: autogptq >=0.8.0 use checkpoint_format: str
|
||||
# compat: autogptq <=0.7.1 is_marlin_format: bool
|
||||
is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin"
|
||||
or hf_quant_cfg.get("is_marlin_format", False))
|
||||
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
|
||||
or user_quant == "marlin")
|
||||
|
||||
if is_marlin_format and is_valid_user_quant:
|
||||
msg = ("The model is serialized in {} format. Using {} kernel.".
|
||||
format(cls.get_name(), cls.get_name()))
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
return None
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["MarlinLinearMethod"]:
|
||||
if (isinstance(layer, LinearBase) or
|
||||
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
|
||||
return MarlinLinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class MarlinLinearMethod(LinearMethodBase):
|
||||
"""Linear method for Marlin.
|
||||
|
||||
Args:
|
||||
quant_config: The Marlin quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: MarlinConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del output_size # Unused.
|
||||
weight_loader = extra_weight_attrs["weight_loader"]
|
||||
|
||||
if params_dtype != torch.float16:
|
||||
raise ValueError(
|
||||
f"The params dtype must be float16, but got {params_dtype}")
|
||||
|
||||
# Validate output_size_per_partition
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if output_size_per_partition % self.quant_config.min_n_threads != 0:
|
||||
raise ValueError(
|
||||
f"Weight output_size_per_partition = "
|
||||
f"{output_size_per_partition} is not divisible by "
|
||||
f"min_n_threads = {self.quant_config.min_n_threads}.")
|
||||
if output_size_per_partition % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
f"Weight output_size_per_partition = "
|
||||
f"{output_size_per_partition} is not divisible by "
|
||||
f"pack_factor = {self.quant_config.pack_factor}.")
|
||||
|
||||
# Validate input_size_per_partition
|
||||
if input_size_per_partition % self.quant_config.min_k_threads != 0:
|
||||
raise ValueError(
|
||||
f"Weight input_size_per_partition = "
|
||||
f"{input_size_per_partition} is not divisible by "
|
||||
f"min_k_threads = {self.quant_config.min_k_threads}.")
|
||||
if (self.quant_config.group_size != -1 and
|
||||
input_size_per_partition % self.quant_config.group_size != 0):
|
||||
raise ValueError(f"Weight input_size_per_partition = "
|
||||
f"{input_size_per_partition} is not divisible by "
|
||||
f"group_size = {self.quant_config.group_size}.")
|
||||
|
||||
# Check that we have at least 4 tiles horizontally in the shard
|
||||
num_tiles_per_perm = self.quant_config.perm_len // (
|
||||
self.quant_config.tile_size**2)
|
||||
if output_size_per_partition % num_tiles_per_perm != 0:
|
||||
raise ValueError(
|
||||
"Each permutation group must reside on the same gpu")
|
||||
|
||||
# Quantized 4Bit weights packed into Int32.
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.tile_size,
|
||||
output_size_per_partition * self.quant_config.tile_size //
|
||||
self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
marlin_tile_size=self.quant_config.tile_size,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
# Determine if channelwise or not
|
||||
input_groups = (1 if self.quant_config.group_size == -1 else
|
||||
input_size_per_partition //
|
||||
self.quant_config.group_size)
|
||||
|
||||
weight_scale_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
input_groups,
|
||||
output_size_per_partition,
|
||||
device="cuda",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
if input_groups == 1:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1,
|
||||
**weight_scale_args)
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(output_dim=1,
|
||||
input_dim=0,
|
||||
**weight_scale_args)
|
||||
|
||||
# Allocate workspace (Used for internal locking mechanism)
|
||||
max_workspace_size = (
|
||||
output_size_per_partition //
|
||||
self.quant_config.min_n_threads) * self.quant_config.max_parallel
|
||||
|
||||
workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
|
||||
device="cuda",
|
||||
dtype=torch.int),
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("B", qweight)
|
||||
layer.register_parameter("s", scales)
|
||||
layer.register_parameter("workspace", workspace)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# required by torch.compile
|
||||
layer.B = Parameter(layer.B.data, requires_grad=False)
|
||||
layer.s = Parameter(layer.s.data, requires_grad=False)
|
||||
layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
qweight = layer.B
|
||||
scales = layer.s
|
||||
workspace = layer.workspace
|
||||
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
size_m = x_2d.shape[0]
|
||||
size_k = x_2d.shape[1]
|
||||
size_n = scales.shape[1]
|
||||
|
||||
output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
|
||||
size_n, size_k)
|
||||
|
||||
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output
|
||||
737
model_executor/layers/quantization/modelopt.py
Normal file
737
model_executor/layers/quantization/modelopt.py
Normal file
@@ -0,0 +1,737 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm._custom_ops import (cutlass_scaled_fp4_mm,
|
||||
cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
apply_fp4_marlin_linear, is_fp4_marlin_supported,
|
||||
prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, requantize_with_max_scale)
|
||||
from vllm.model_executor.parameter import (ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
QUANT_ALGOS = ["FP8", "NVFP4"]
|
||||
KV_CACHE_QUANT_ALGOS = ["FP8"]
|
||||
|
||||
|
||||
class ModelOptFp8Config(QuantizationConfig):
|
||||
"""Config class for ModelOpt FP8."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
is_checkpoint_fp8_serialized: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
|
||||
if is_checkpoint_fp8_serialized:
|
||||
logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
|
||||
" the format is experimental and could change.")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "modelopt"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 89
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["hf_quant_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
|
||||
quant_config = cls.get_from_keys(config, ["quantization"])
|
||||
quant_method = quant_config["quant_algo"]
|
||||
if quant_method not in QUANT_ALGOS:
|
||||
raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
|
||||
" quantizations in vLLM. Please check the "
|
||||
"`hf_quant_config.json` file for your model's "
|
||||
"quant configuration.")
|
||||
is_checkpoint_fp8_serialized = ("FP8" in quant_method)
|
||||
|
||||
return cls(is_checkpoint_fp8_serialized)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
if isinstance(layer, LinearBase):
|
||||
return ModelOptFp8LinearMethod(self)
|
||||
elif isinstance(layer, Attention):
|
||||
return ModelOptFp8KVCacheMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class ModelOptFp8LinearMethod(LinearMethodBase):
|
||||
"""Linear method for Model Optimizer static quantization.
|
||||
Supports loading FP8 checkpoints with static weight scale and
|
||||
activation scale. Future support might be added for dynamic
|
||||
scales.
|
||||
|
||||
Limitations:
|
||||
1. Only support per-tensor quantization due to torch._scaled_mm support.
|
||||
2. Only support float8_e4m3fn datatype
|
||||
Args: quant_config: The ModelOpt quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: ModelOptFp8Config):
|
||||
self.quant_config = quant_config
|
||||
self.fp8_linear = Fp8LinearOp()
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del input_size, output_size
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
weight_dtype = (torch.float8_e4m3fn
|
||||
if self.quant_config.is_checkpoint_fp8_serialized else
|
||||
params_dtype)
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=weight_dtype),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
# WEIGHT SCALE
|
||||
weight_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
weight_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
# INPUT SCALE
|
||||
scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("input_scale", scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
weight = layer.weight
|
||||
max_w_scale = layer.weight_scale.max()
|
||||
if not (layer.weight_scale == layer.weight_scale[0]).all():
|
||||
max_w_scale, weight = requantize_with_max_scale(
|
||||
layer.weight, layer.weight_scale, layer.logical_widths)
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
||||
layer.input_scale = Parameter(layer.input_scale.max(),
|
||||
requires_grad=False)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.fp8_linear.apply(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias)
|
||||
|
||||
|
||||
class ModelOptNvFp4Config(QuantizationConfig):
|
||||
"""Config class for ModelOpt FP4."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
is_checkpoint_nvfp4_serialized: bool,
|
||||
kv_cache_quant_algo: str,
|
||||
exclude_modules: list[str],
|
||||
group_size: int = 16,
|
||||
) -> None:
|
||||
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
|
||||
if is_checkpoint_nvfp4_serialized:
|
||||
logger.warning(
|
||||
"Detected ModelOpt NVFP4 checkpoint. Please note that"
|
||||
" the format is experimental and could change in future.")
|
||||
|
||||
self.group_size = group_size
|
||||
self.kv_cache_quant_algo = kv_cache_quant_algo
|
||||
self.exclude_modules = exclude_modules
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "modelopt_fp4"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half, torch.float8_e4m3fn]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["hf_quant_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
|
||||
quant_config = cls.get_from_keys(config, ["quantization"])
|
||||
quant_method = quant_config["quant_algo"]
|
||||
if quant_method not in QUANT_ALGOS:
|
||||
raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
|
||||
" quantizations in vLLM. Please check the "
|
||||
"`hf_quant_config.json` file for your model's "
|
||||
"quant configuration.")
|
||||
is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method)
|
||||
if ("group_size" and "kv_cache_quant_algo"
|
||||
and "exclude_modules") not in quant_config:
|
||||
raise ValueError("NVFP4 quantization requires group size and "
|
||||
"kv_cache_quant_algo specified in "
|
||||
"hf_quant_config.json")
|
||||
kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
|
||||
group_size = quant_config["group_size"]
|
||||
exclude_modules = quant_config["exclude_modules"]
|
||||
return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
|
||||
exclude_modules, group_size)
|
||||
|
||||
def is_layer_excluded(self, prefix: str, exclude_modules: list):
|
||||
import regex as re
|
||||
for pattern in exclude_modules:
|
||||
regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
|
||||
if re.fullmatch(regex_str, prefix):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
if isinstance(layer, LinearBase):
|
||||
if (is_layer_skipped(prefix, self.exclude_modules)
|
||||
or self.is_layer_excluded(prefix, self.exclude_modules)):
|
||||
return UnquantizedLinearMethod()
|
||||
return ModelOptNvFp4LinearMethod(self)
|
||||
elif isinstance(layer, Attention):
|
||||
return ModelOptFp8KVCacheMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return ModelOptNvFp4FusedMoE(self)
|
||||
return None
|
||||
|
||||
|
||||
def cutlass_fp4_supported() -> bool:
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
capability = -1 if capability_tuple is None else capability_tuple.to_int()
|
||||
return cutlass_scaled_mm_supports_fp4(capability)
|
||||
|
||||
|
||||
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
|
||||
"""
|
||||
Supports loading kv-cache scaling factors from FP8 checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: Union[ModelOptFp8Config,
|
||||
ModelOptNvFp4Config]):
|
||||
super().__init__(quant_config)
|
||||
|
||||
|
||||
class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
"""Linear method for Model Optimizer NVFP4.
|
||||
Supports loading NVFP4 checkpoints with the following structure:
|
||||
|
||||
input_scale: torch.float32, scalar ,
|
||||
weight: NVFP4(represented as byte) Shape: [1, X, y/2]
|
||||
weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale,
|
||||
weight_scale_2: torch.float32, scalar,
|
||||
Args: quant_config: The ModelOpt quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: ModelOptNvFp4Config):
|
||||
self.quant_config = quant_config
|
||||
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
|
||||
self.use_marlin = False
|
||||
|
||||
if not self.cutlass_nvfp4_supported:
|
||||
if is_fp4_marlin_supported():
|
||||
self.use_marlin = True
|
||||
else:
|
||||
raise ValueError("Current platform does not support NVFP4"
|
||||
" quantization. Please use Blackwell and"
|
||||
" above.")
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del input_size, output_size
|
||||
if not self.quant_config.is_checkpoint_nvfp4_serialized:
|
||||
raise ValueError("NVFP4 quantization was selected, "
|
||||
" dynamic quantization is not supported.")
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
|
||||
if (input_size_per_partition % 16 != 0):
|
||||
raise ValueError("Unsupported model when in features size is "
|
||||
"not multiple of 16")
|
||||
# The nvfp4 weight is still represented as
|
||||
weight_dtype = (torch.float8_e4m3fn
|
||||
if self.quant_config.is_checkpoint_nvfp4_serialized
|
||||
else params_dtype)
|
||||
# Weight
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
layer.output_size_per_partition,
|
||||
layer.input_size_per_partition // 2,
|
||||
dtype=torch.uint8),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# Input Weight Scale
|
||||
input_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
# Global Weight Scale
|
||||
weight_scale_2 = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight_scale_2", weight_scale_2)
|
||||
|
||||
# Per Block Weight Scale
|
||||
weight_scale = ModelWeightParameter(data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // self.quant_config.group_size,
|
||||
dtype=weight_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def swizzle_blockscale(self, scale: torch.tensor):
|
||||
assert (scale.dtype == torch.float8_e4m3fn)
|
||||
# Pad and blockwise interleave weight_scale
|
||||
scale_ndim = scale.ndim
|
||||
if scale.ndim == 2:
|
||||
scale = scale.unsqueeze(0)
|
||||
assert scale.ndim == 3
|
||||
B, M, K = scale.shape
|
||||
round_up_multiple = lambda x, m: (x + m - 1) // m * m
|
||||
M_padded = round_up_multiple(M, 128)
|
||||
K_padded = round_up_multiple(K, 4)
|
||||
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
|
||||
padded_scale[:B, :M, :K] = scale
|
||||
batches, rows, cols = padded_scale.shape
|
||||
assert rows % 128 == 0
|
||||
assert cols % 4 == 0
|
||||
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
|
||||
cols // 4, 4)
|
||||
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
|
||||
swizzled_scale = swizzled_scale.contiguous().cuda()
|
||||
return (swizzled_scale.reshape(M, K)
|
||||
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
|
||||
# global scales:
|
||||
input_scale_2 = layer.input_scale.max().to(torch.float32)
|
||||
layer.input_scale = Parameter(input_scale_2, requires_grad=False)
|
||||
|
||||
weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
|
||||
layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
|
||||
|
||||
layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
|
||||
requires_grad=False)
|
||||
|
||||
# Swizzle the weight blockscale.
|
||||
# contracting dimension is input dimension
|
||||
# block_size = 16;
|
||||
assert (layer.weight_scale.shape[1] % 16 == 0), (
|
||||
"Expected weight_scale.dim(1) to be divisible by 16")
|
||||
assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
|
||||
"Weight Block scale must be represented as FP8-E4M3")
|
||||
swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
|
||||
|
||||
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
|
||||
requires_grad=False)
|
||||
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
||||
|
||||
if self.use_marlin:
|
||||
prepare_fp4_layer_for_marlin(layer)
|
||||
del layer.alpha
|
||||
del layer.input_scale
|
||||
del layer.weight_scale_swizzled
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if self.use_marlin:
|
||||
return apply_fp4_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_scale_2=layer.weight_scale_2,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias)
|
||||
|
||||
output_dtype = x.dtype
|
||||
output_shape = [x.shape[0], layer.weight.shape[0]]
|
||||
|
||||
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
|
||||
s_quant = 1 / layer.input_scale
|
||||
x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant)
|
||||
|
||||
# validate dtypes of quantized input, input block scale,
|
||||
# weight and weight_blockscale
|
||||
assert (x_fp4.dtype == torch.uint8)
|
||||
assert (layer.weight.dtype == torch.uint8)
|
||||
assert (x_blockscale.dtype == torch.float8_e4m3fn)
|
||||
assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
|
||||
assert (layer.alpha.dtype == torch.float32)
|
||||
|
||||
out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
|
||||
layer.weight_scale_swizzled, layer.alpha,
|
||||
output_dtype)
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out.view(*output_shape)
|
||||
|
||||
|
||||
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
"""
|
||||
MoE Method for FP4 Quantization.
|
||||
Args:
|
||||
quant_config: NVFP4 Quant Config
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: ModelOptNvFp4Config):
|
||||
self.quant_config = quant_config
|
||||
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
|
||||
self.use_marlin = False
|
||||
|
||||
if not self.cutlass_nvfp4_supported:
|
||||
if is_fp4_marlin_supported():
|
||||
self.use_marlin = True
|
||||
else:
|
||||
raise ValueError("Current platform does not support NVFP4"
|
||||
" quantization. Please use Blackwell and"
|
||||
" above.")
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
if not self.quant_config.is_checkpoint_nvfp4_serialized:
|
||||
raise ValueError("NVFP4 quantization was selected, "
|
||||
" dynamic quantization is not supported.")
|
||||
|
||||
layer.num_experts = num_experts
|
||||
layer.params_dtype = params_dtype
|
||||
layer.quant_config = self.quant_config
|
||||
weight_dtype = torch.uint8
|
||||
weight_scale_dtype = torch.float8_e4m3fn
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
# GEMM 1
|
||||
w13_weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
hidden_size // 2,
|
||||
dtype=weight_dtype),
|
||||
input_dim=1,
|
||||
output_dim=2,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
|
||||
# GEMM 2
|
||||
w2_weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
intermediate_size_per_partition // 2,
|
||||
dtype=weight_dtype),
|
||||
input_dim=1,
|
||||
output_dim=2,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
|
||||
w13_weight_scale = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
hidden_size // self.quant_config.group_size,
|
||||
dtype=weight_scale_dtype),
|
||||
input_dim=1,
|
||||
output_dim=2,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
|
||||
w2_weight_scale = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
intermediate_size_per_partition //
|
||||
self.quant_config.group_size,
|
||||
dtype=weight_scale_dtype),
|
||||
input_dim=1,
|
||||
output_dim=2,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})
|
||||
|
||||
w13_weight_scale_2 = PerTensorScaleParameter(
|
||||
data=torch.empty(num_experts, 2, dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
|
||||
|
||||
w2_weight_scale_2 = PerTensorScaleParameter(
|
||||
data=torch.empty(num_experts, dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
|
||||
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
||||
|
||||
w13_input_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
num_experts, 2, dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
|
||||
w2_input_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
num_experts, dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
|
||||
def swizzle_blockscale(self, scale: torch.tensor):
|
||||
assert (scale.dtype == torch.float8_e4m3fn)
|
||||
# Pad and blockwise interleave weight_scale
|
||||
scale_ndim = scale.ndim
|
||||
if scale.ndim == 2:
|
||||
scale = scale.unsqueeze(0)
|
||||
assert scale.ndim == 3
|
||||
B, M, K = scale.shape
|
||||
round_up_multiple = lambda x, m: (x + m - 1) // m * m
|
||||
M_padded = round_up_multiple(M, 128)
|
||||
K_padded = round_up_multiple(K, 4)
|
||||
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
|
||||
padded_scale[:B, :M, :K] = scale
|
||||
batches, rows, cols = padded_scale.shape
|
||||
assert rows % 128 == 0
|
||||
assert cols % 4 == 0
|
||||
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
|
||||
cols // 4, 4)
|
||||
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
|
||||
swizzled_scale = swizzled_scale.contiguous().cuda()
|
||||
return (swizzled_scale.reshape(M, K)
|
||||
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
|
||||
# GEMM 1
|
||||
if not torch.allclose(layer.w13_weight_scale_2[:, 0],
|
||||
layer.w13_weight_scale_2[:, 1]):
|
||||
logger.warning_once(
|
||||
"w1_weight_scale_2 must match w3_weight_scale_2. "
|
||||
"Accuracy may be affected.")
|
||||
|
||||
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
|
||||
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
|
||||
requires_grad=False)
|
||||
|
||||
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
|
||||
torch.float32)
|
||||
layer.g1_alphas = Parameter(
|
||||
(w13_input_scale * w13_weight_scale_2).to(torch.float32),
|
||||
requires_grad=False)
|
||||
|
||||
assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
|
||||
"Expected weight_scale.dim(1) to be divisible by 16")
|
||||
assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
|
||||
"Weight Blockscale must be represented as FP8-E4M3")
|
||||
w13_blockscale_swizzled = self.swizzle_blockscale(
|
||||
layer.w13_weight_scale)
|
||||
|
||||
layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
|
||||
requires_grad=False)
|
||||
|
||||
# This is for quantization, so we need to invert it.
|
||||
layer.w13_input_scale_quant = Parameter(
|
||||
(1 / w13_input_scale).to(torch.float32), requires_grad=False)
|
||||
|
||||
layer.w13_weight = Parameter(layer.w13_weight.data,
|
||||
requires_grad=False)
|
||||
|
||||
# GEMM 2
|
||||
layer.g2_alphas = Parameter(
|
||||
(layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
|
||||
requires_grad=False)
|
||||
|
||||
# This is for quantization, so we need to invert it.
|
||||
layer.w2_input_scale_quant = Parameter(
|
||||
(1 / layer.w2_input_scale).to(torch.float32), requires_grad=False)
|
||||
|
||||
assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
|
||||
"Expected weight_scale.dim(1) to be divisible by 16")
|
||||
assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
|
||||
"Weight Blockscale must be represented as FP8-E4M3")
|
||||
w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
|
||||
|
||||
layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
|
||||
|
||||
if self.use_marlin:
|
||||
prepare_moe_fp4_layer_for_marlin(layer)
|
||||
del layer.g1_alphas
|
||||
del layer.g2_alphas
|
||||
del layer.w13_input_scale_quant
|
||||
del layer.w2_input_scale_quant
|
||||
del layer.w13_blockscale_swizzled
|
||||
del layer.w2_blockscale_swizzled
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
):
|
||||
if self.use_marlin:
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
layer.w13_weight_scale,
|
||||
layer.w2_weight_scale,
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_scale1=layer.w13_weight_scale_2,
|
||||
global_scale2=layer.w2_weight_scale_2,
|
||||
quant_type_id=scalar_types.float4_e2m1f.id,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
assert not apply_router_weight_on_input, (
|
||||
"Router weight on input is not "
|
||||
"supported for ModelOptNvFp4FusedMoE.")
|
||||
assert expert_map is None, ("Expert Parallelism / expert_map "
|
||||
"is currently not supported for "
|
||||
"ModelOptNvFp4FusedMoE.")
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
cutlass_moe_fp4)
|
||||
|
||||
# Cutlass moe takes in activations in BF16/Half precision
|
||||
# and fp4 quantized weights loaded from the checkpoint
|
||||
return cutlass_moe_fp4(a=x,
|
||||
w1_fp4=layer.w13_weight,
|
||||
w1_blockscale=layer.w13_blockscale_swizzled,
|
||||
w1_alphas=layer.g1_alphas,
|
||||
w2_fp4=layer.w2_weight,
|
||||
w2_blockscale=layer.w2_blockscale_swizzled,
|
||||
w2_alphas=layer.g2_alphas,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
m=x.shape[0],
|
||||
n=layer.w2_weight.shape[2] * 2,
|
||||
k=x.shape[1],
|
||||
e=layer.w13_weight.shape[0],
|
||||
a1_gscale=layer.w13_input_scale_quant,
|
||||
a2_gscale=layer.w2_input_scale_quant,
|
||||
device=x.device).to(x.dtype)
|
||||
469
model_executor/layers/quantization/moe_wna16.py
Normal file
469
model_executor/layers/quantization/moe_wna16.py
Normal file
@@ -0,0 +1,469 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supports_layer)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class MoeWNA16Config(QuantizationConfig):
|
||||
"""Config class for MOE WNA16 (W8A16/W4A16) quantization."""
|
||||
|
||||
def __init__(self, linear_quant_method: str, weight_bits: int,
|
||||
group_size: int, has_zp: bool, lm_head_quantized: bool,
|
||||
modules_to_not_convert: Optional[list[str]],
|
||||
full_config: dict[str, Any]) -> None:
|
||||
super().__init__()
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.has_zp = has_zp
|
||||
self.bit8_pack_factor = 8 // self.weight_bits
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.linear_quant_method = linear_quant_method
|
||||
self.full_config = full_config
|
||||
self.use_marlin = False
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import (
|
||||
AWQMarlinConfig)
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinConfig)
|
||||
"""
|
||||
if self.linear_quant_method == "gptq":
|
||||
self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(
|
||||
full_config)
|
||||
elif self.linear_quant_method == "awq":
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
device_capability = (-1 if capability_tuple is None else
|
||||
capability_tuple.to_int())
|
||||
awq_min_capability = AWQConfig.get_min_capability()
|
||||
if device_capability < awq_min_capability:
|
||||
raise ValueError(
|
||||
"The quantization method moe_wna16 + awq is not supported "
|
||||
"for the current GPU. "
|
||||
f"Minimum capability: {awq_min_capability}. "
|
||||
f"Current capability: {device_capability}.")
|
||||
self.use_marlin = AWQMarlinConfig.is_awq_marlin_compatible(
|
||||
full_config)
|
||||
else:
|
||||
raise ValueError("moe_wna16 only support gptq and awq.")
|
||||
"""
|
||||
if modules_to_not_convert is None:
|
||||
self.modules_to_not_convert = []
|
||||
else:
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "moe_wna16"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "MoeWNA16Config":
|
||||
linear_quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
if linear_quant_method == "gptq":
|
||||
has_zp = not cls.get_from_keys(config, ["sym"])
|
||||
modules_to_not_convert = []
|
||||
elif linear_quant_method == "awq":
|
||||
has_zp = cls.get_from_keys(config, ["zero_point"])
|
||||
modules_to_not_convert = cls.get_from_keys_or(
|
||||
config, ["modules_to_not_convert"], None)
|
||||
else:
|
||||
raise ValueError("moe_wna16 only support gptq and awq.")
|
||||
|
||||
return cls(linear_quant_method, weight_bits, group_size, has_zp,
|
||||
lm_head_quantized, modules_to_not_convert, config)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
|
||||
if can_convert and user_quant == "moe_wna16":
|
||||
return cls.get_name()
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_moe_wna16_compatible(cls, quant_config: dict[str, Any]):
|
||||
# Extract data from quant config.
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits")
|
||||
desc_act = quant_config.get("desc_act")
|
||||
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
device_capability = (-1 if capability_tuple is None else
|
||||
capability_tuple.to_int())
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
awq_min_capability = AWQConfig.get_min_capability()
|
||||
|
||||
gptq_compatible = quant_method == "gptq" and \
|
||||
not desc_act and num_bits in [4, 8]
|
||||
awq_compatible = quant_method == "awq" and num_bits == 4 and \
|
||||
device_capability >= awq_min_capability
|
||||
|
||||
return gptq_compatible or awq_compatible
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
|
||||
return UnquantizedLinearMethod()
|
||||
elif isinstance(layer, LinearBase):
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import (
|
||||
AWQMarlinConfig)
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinConfig)
|
||||
if self.linear_quant_method == "gptq":
|
||||
if self.use_marlin:
|
||||
return GPTQMarlinConfig.from_config(
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
else:
|
||||
return GPTQConfig.from_config(
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
elif self.linear_quant_method == "awq":
|
||||
if self.use_marlin and check_marlin_supports_layer(
|
||||
layer, self.group_size):
|
||||
return AWQMarlinConfig.from_config(
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
else:
|
||||
return AWQConfig.from_config(
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
else:
|
||||
raise ValueError("moe_wna16 only support gptq and awq.")
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return MoeWNA16Method(self)
|
||||
return None
|
||||
|
||||
|
||||
def is_layer_skipped_quant(prefix: str, modules_to_not_convert: list[str]):
|
||||
return any(module_name in prefix for module_name in modules_to_not_convert)
|
||||
|
||||
|
||||
class MoeWNA16Method(FusedMoEMethodBase):
|
||||
"""Linear method for MOE WNA16 (W8A16/W4A16) quantization.
|
||||
|
||||
Args:
|
||||
quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: MoeWNA16Config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
|
||||
layer.quant_config = self.quant_config
|
||||
bit8_pack_factor = self.quant_config.bit8_pack_factor
|
||||
group_size = self.quant_config.group_size
|
||||
group_size_div_factor = 1
|
||||
|
||||
# make intermediate_size and hidden_size diviable by group_size
|
||||
# we reduce the group size to ensure that
|
||||
# and we would repeat the loaded_weight later
|
||||
while intermediate_size_per_partition % group_size or \
|
||||
hidden_size % group_size:
|
||||
group_size = group_size // 2
|
||||
group_size_div_factor *= 2
|
||||
assert group_size >= 32
|
||||
layer.group_size = group_size
|
||||
layer.group_size_div_factor = group_size_div_factor
|
||||
|
||||
strategy = FusedMoeWeightScaleSupported.GROUP.value
|
||||
extra_weight_attrs.update({
|
||||
"quant_method": strategy,
|
||||
"is_transposed": False
|
||||
})
|
||||
|
||||
assert 'weight_loader' in extra_weight_attrs
|
||||
weight_loader = extra_weight_attrs['weight_loader']
|
||||
wrapped_weight_loader = MoeWNA16Method.get_weight_loader(
|
||||
layer, weight_loader)
|
||||
extra_weight_attrs['weight_loader'] = wrapped_weight_loader
|
||||
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_qweight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // bit8_pack_factor,
|
||||
dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_qweight", w13_qweight)
|
||||
set_weight_attrs(w13_qweight, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_qweight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition // bit8_pack_factor,
|
||||
dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_qweight", w2_qweight)
|
||||
set_weight_attrs(w2_qweight, extra_weight_attrs)
|
||||
|
||||
w13_scales = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // group_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_scales", w13_scales)
|
||||
set_weight_attrs(w13_scales, extra_weight_attrs)
|
||||
|
||||
w2_scales = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition // group_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_scales", w2_scales)
|
||||
set_weight_attrs(w2_scales, extra_weight_attrs)
|
||||
|
||||
if self.quant_config.has_zp:
|
||||
w13_qzeros = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition // bit8_pack_factor,
|
||||
hidden_size // group_size,
|
||||
dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_qzeros", w13_qzeros)
|
||||
set_weight_attrs(w13_qzeros, extra_weight_attrs)
|
||||
|
||||
w2_qzeros = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
hidden_size // bit8_pack_factor,
|
||||
intermediate_size_per_partition // group_size,
|
||||
dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_qzeros", w2_qzeros)
|
||||
set_weight_attrs(w2_qzeros, extra_weight_attrs)
|
||||
|
||||
if self.quant_config.linear_quant_method == "gptq":
|
||||
# some param are unused, but we need to init them in order to
|
||||
# load weights
|
||||
invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
|
||||
if not self.quant_config.has_zp:
|
||||
invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
|
||||
for key in invalid_param_keys:
|
||||
param = torch.nn.Parameter(torch.empty((0, ),
|
||||
dtype=torch.int32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter(key, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
weight_bits = self.quant_config.weight_bits
|
||||
has_zp = self.quant_config.has_zp
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
use_int4_w4a16=weight_bits == 4,
|
||||
use_int8_w8a16=weight_bits == 8,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_scales,
|
||||
w2_scale=layer.w2_scales,
|
||||
w1_zp=layer.w13_qzeros if has_zp else None,
|
||||
w2_zp=layer.w2_qzeros if has_zp else None,
|
||||
block_shape=[0, layer.group_size])
|
||||
|
||||
@staticmethod
|
||||
def get_weight_loader(layer, weight_loader):
|
||||
|
||||
def convert_awq_tensor(tensor, tensor_type):
|
||||
# convert awq qweight/qzeros to a standard format (assume int4)
|
||||
# qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
|
||||
# qzeros: (k // group_size, n // pack_factor_bit32) ->
|
||||
# (n // pack_factor_bit8, k // group_size)
|
||||
# pack_factor_bit32 = 32 // weight_bits
|
||||
# pack_factor_bit8 = 8 // weight_bits
|
||||
|
||||
# 0. suppose origin shape (a, b), dtype int32
|
||||
# 1. convert to uint8, shape (a, b) -> (a, 4 * b)
|
||||
size0 = tensor.size(0)
|
||||
tensor = tensor.view(torch.uint8)
|
||||
|
||||
# 2. unpack to uint4 (only when weight_bits == 4)
|
||||
# shape (a, 4 * b) -> (a, 4 * b, 2)
|
||||
shifter = torch.tensor([0, 4],
|
||||
dtype=torch.uint8,
|
||||
device=tensor.device)
|
||||
tensor = (tensor[:, :, None] >> shifter) & 0xF
|
||||
|
||||
# 3. change order, see
|
||||
# https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
|
||||
# shape -> (a, 4 * b * pack_factor_bit8)
|
||||
reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
|
||||
tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order]
|
||||
tensor = tensor.view(size0, -1)
|
||||
|
||||
# 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
|
||||
tensor = tensor.T.contiguous()
|
||||
|
||||
# 5. repack (only when weight_bits == 4)
|
||||
# qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
|
||||
# qzeros shape -> (4 * b, a)
|
||||
|
||||
if tensor_type == "qweight":
|
||||
tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
|
||||
elif tensor_type == "qzeros":
|
||||
tensor = tensor[1::2, :] * 16 + tensor[::2, :]
|
||||
return tensor
|
||||
|
||||
def convert_gptq_int4_qzeros(tensor):
|
||||
tensor = tensor.view(torch.uint8)
|
||||
shifter = torch.tensor([0, 4],
|
||||
dtype=torch.uint8,
|
||||
device=tensor.device)
|
||||
tensor = (tensor[:, :, None] >> shifter) & 0xF
|
||||
tensor = tensor + 1
|
||||
tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
|
||||
return tensor
|
||||
|
||||
def moe_wna16_weight_loader(param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
weight_name: str, shard_id: str,
|
||||
expert_id: int):
|
||||
if layer.ep_size > 1:
|
||||
global_expert_id = expert_id
|
||||
expert_id = layer._map_global_expert_id_to_local_expert_id(expert_id)
|
||||
if expert_id == -1:
|
||||
return
|
||||
|
||||
if "g_idx" in weight_name:
|
||||
return
|
||||
if not layer.quant_config.has_zp and "qzeros" in weight_name:
|
||||
return
|
||||
|
||||
device = get_tp_group().device
|
||||
if layer.ep_size > 1:
|
||||
tp_rank = 0
|
||||
else:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
loaded_weight = loaded_weight.to(device)
|
||||
shard_size = layer.intermediate_size_per_partition
|
||||
|
||||
# convert gptq and awq weight to a standard format
|
||||
if layer.quant_config.linear_quant_method == "awq":
|
||||
assert layer.quant_config.weight_bits == 4
|
||||
if "weight" in weight_name:
|
||||
loaded_weight = convert_awq_tensor(loaded_weight,
|
||||
"qweight")
|
||||
elif "zeros" in weight_name:
|
||||
loaded_weight = convert_awq_tensor(loaded_weight, "qzeros")
|
||||
else:
|
||||
loaded_weight = loaded_weight.T
|
||||
elif layer.quant_config.linear_quant_method == "gptq":
|
||||
assert layer.quant_config.weight_bits in [4, 8]
|
||||
if "weight" in weight_name:
|
||||
loaded_weight = loaded_weight.T.contiguous().view(
|
||||
torch.uint8)
|
||||
elif "zeros" in weight_name:
|
||||
# add 1 to gptq qzeros to align with awq
|
||||
loaded_weight = loaded_weight.view(torch.uint8)
|
||||
if layer.quant_config.weight_bits == 4:
|
||||
loaded_weight = convert_gptq_int4_qzeros(
|
||||
loaded_weight).T
|
||||
else:
|
||||
loaded_weight = loaded_weight.T + 1
|
||||
else:
|
||||
loaded_weight = loaded_weight.T
|
||||
|
||||
# repeat the qzeros/scales to fit new group size
|
||||
if layer.group_size_div_factor > 1 and \
|
||||
"qzeros" in weight_name or "scales" in weight_name:
|
||||
loaded_weight = loaded_weight.repeat_interleave(
|
||||
layer.group_size_div_factor, 1)
|
||||
|
||||
if "w13_qzeros" in weight_name:
|
||||
if layer.ep_size > 1 :
|
||||
tensor = loaded_weight.view(-1, param.data[expert_id].shape[0] // 2,
|
||||
loaded_weight.size(1))[tp_rank]
|
||||
else:
|
||||
tensor = loaded_weight.view(layer.tp_size, -1,
|
||||
loaded_weight.size(1))[tp_rank]
|
||||
if shard_id == "w1":
|
||||
param.data[expert_id, :shard_size // 2] = tensor
|
||||
else:
|
||||
param.data[expert_id, shard_size // 2:] = tensor
|
||||
elif "w2_qzeros" in weight_name:
|
||||
if layer.ep_size > 1 :
|
||||
param.data[expert_id] = loaded_weight.view(
|
||||
loaded_weight.size(0), -1, param.data[expert_id].shape[1])[:, tp_rank]
|
||||
else:
|
||||
param.data[expert_id] = loaded_weight.view(
|
||||
loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank]
|
||||
else:
|
||||
if layer.ep_size > 1:
|
||||
expert_id = global_expert_id
|
||||
weight_loader(param, loaded_weight, weight_name, shard_id,
|
||||
expert_id)
|
||||
|
||||
return moe_wna16_weight_loader
|
||||
76
model_executor/layers/quantization/neuron_quant.py
Normal file
76
model_executor/layers/quantization/neuron_quant.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from importlib.util import find_spec
|
||||
from typing import Any, Optional
|
||||
|
||||
from torch.nn import Module
|
||||
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
|
||||
SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn']
|
||||
|
||||
|
||||
class AlwaysSupportedDtypes(list):
|
||||
|
||||
def __contains__(self, item):
|
||||
return True
|
||||
|
||||
|
||||
class NeuronQuantConfig(QuantizationConfig):
|
||||
"""Int8 Quantization Config class for Neuron Backend."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dequant_dtype: str = "f16",
|
||||
quantize_method: str = "vector_dynamic",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
|
||||
if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST:
|
||||
raise ValueError(
|
||||
f"Neuron quantization datatype {self.quant_dtype} is not valid,"
|
||||
f" the quantization datatype should match one of the below "
|
||||
f"types {SUPPORTED_QUANT_DTYPE_LIST}")
|
||||
self.dequant_dtype = dequant_dtype
|
||||
self.quantize_method = quantize_method
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "neuron_quant"
|
||||
|
||||
def get_supported_act_dtypes(self) -> list[str]:
|
||||
# Neuron implements custom handling logic for quantization support
|
||||
return AlwaysSupportedDtypes()
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
raise NotImplementedError(
|
||||
"This function should not be called with Neuron Backend")
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "NeuronQuantConfig":
|
||||
quantize_method = cls.get_from_keys(config, ["quantize_method"])
|
||||
dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"])
|
||||
return cls(dequant_dtype=dequant_dtype,
|
||||
quantize_method=quantize_method)
|
||||
|
||||
def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]:
|
||||
if find_spec("transformers_neuronx") is not None:
|
||||
return self.get_quantization_config()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Neuron Quantization is only supported through"
|
||||
" transformers_neuronx.")
|
||||
|
||||
def get_quantization_config(self):
|
||||
from transformers_neuronx.config import QuantizationConfig
|
||||
return QuantizationConfig(quant_dtype=self.quant_dtype,
|
||||
dequant_dtype=self.dequant_dtype,
|
||||
quantize_method=self.quantize_method)
|
||||
127
model_executor/layers/quantization/ptpc_fp8.py
Normal file
127
model_executor/layers/quantization/ptpc_fp8.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
|
||||
Fp8KVCacheMethod,
|
||||
Fp8LinearMethod)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class PTPCFp8Config(Fp8Config):
|
||||
"""Config class for Per-Token-Per-Channel Dynamic Quantization Fp8."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
activation_scheme: str = "dynamic",
|
||||
ignored_layers: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
if not current_platform.is_rocm():
|
||||
raise ValueError(
|
||||
"ptpc_fp8 quantization is supported only on ROCm.")
|
||||
|
||||
if not current_platform.has_device_capability(94):
|
||||
raise ValueError(
|
||||
"ptpc_fp8 quantization is supported only on AMD Instinct MI300 GPUs and newer." # noqa: E501
|
||||
)
|
||||
if activation_scheme == "static":
|
||||
raise ValueError(
|
||||
"ptpc_fp8 as of now only support dynamic quantization.")
|
||||
|
||||
super().__init__(is_checkpoint_fp8_serialized=False,
|
||||
activation_scheme=activation_scheme,
|
||||
ignored_layers=ignored_layers)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "ptpc_fp8"
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "PTPCFp8Config":
|
||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
||||
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
|
||||
return cls(activation_scheme=activation_scheme,
|
||||
ignored_layers=ignored_layers)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(prefix, self.ignored_layers):
|
||||
return UnquantizedLinearMethod()
|
||||
return PTPCFp8LinearMethod(self)
|
||||
elif isinstance(layer, Attention):
|
||||
return Fp8KVCacheMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class PTPCFp8LinearMethod(Fp8LinearMethod):
|
||||
"""Linear method for Per-Token and Per-Channel FP8 Quantization.
|
||||
Only supports loading quantized BF16 model checkpoints with dynamic
|
||||
activation scaling. To load FP16 model checkpoints, user must specify
|
||||
to convert the FP16 model weight loading into BF16.
|
||||
The weight scaling factor will be initialized after
|
||||
the model weights are loaded.
|
||||
|
||||
Limitations:
|
||||
1. Only support float8_e4m3fnuz data type due to the limitation of
|
||||
torch._scaled_mm (https://github.com/ROCm/pytorch/blob/8c0504d7f3fb0ee4c278c096a5c3caedb01129fa/aten/src/ATen/native/cuda/Blas.cpp#L1041)
|
||||
|
||||
Args:
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: PTPCFp8Config):
|
||||
super().__init__(quant_config=quant_config)
|
||||
# Force weight quantization
|
||||
self.quant_config.is_checkpoint_fp8_serialized = False
|
||||
self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=False,
|
||||
use_per_token_if_dynamic=True)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.weight = torch.nn.Parameter(layer.weight.data,
|
||||
requires_grad=False)
|
||||
|
||||
assert layer.weight.data.dtype == torch.bfloat16, \
|
||||
f"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. {str(layer.weight.data.dtype)} is specified." # noqa: E501
|
||||
# Quantize the weights.
|
||||
qweight, weight_scale = ops.scaled_fp8_quant(
|
||||
layer.weight, scale=None, use_per_token_if_dynamic=True)
|
||||
|
||||
# Update the layer with the new values.
|
||||
layer.weight = Parameter(
|
||||
qweight.t(), requires_grad=False) # Pretranspose the weight
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
layer.input_scale = None
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
return self.fp8_linear.apply(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=None,
|
||||
input_scale_ub=None,
|
||||
bias=bias)
|
||||
275
model_executor/layers/quantization/qqq.py
Normal file
275
model_executor/layers/quantization/qqq.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MARLIN_QQQ_TILE = 16
|
||||
MARLIN_QQQ_MIN_THREAD_N = 64
|
||||
MARLIN_QQQ_MIN_THREAD_K = 128
|
||||
MARLIN_QQQ_MAX_PARALLEL = 16
|
||||
|
||||
MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
|
||||
MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
|
||||
MARLIN_QQQ_SUPPORTED_SYM = [True]
|
||||
|
||||
|
||||
class QQQConfig(QuantizationConfig):
|
||||
"""Config class for QQQ
|
||||
|
||||
Reference: https://arxiv.org/pdf/2406.09904
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
is_sym: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.is_sym = is_sym
|
||||
|
||||
# Verify
|
||||
if self.weight_bits not in MARLIN_QQQ_SUPPORTED_NUM_BITS:
|
||||
raise ValueError(
|
||||
f"QQQ does not support weight_bits = {self.weight_bits}. "
|
||||
f"Only weight_bits = {MARLIN_QQQ_SUPPORTED_NUM_BITS} "
|
||||
"are supported.")
|
||||
if self.group_size not in MARLIN_QQQ_SUPPORTED_GROUP_SIZES:
|
||||
raise ValueError(
|
||||
f"QQQ does not support group_size = {self.group_size}. "
|
||||
f"Only group_sizes = {MARLIN_QQQ_SUPPORTED_GROUP_SIZES} "
|
||||
"are supported.")
|
||||
if self.is_sym not in MARLIN_QQQ_SUPPORTED_SYM:
|
||||
raise ValueError(
|
||||
f"QQQ does not support is_sym = {self.is_sym}. "
|
||||
f"Only sym = {MARLIN_QQQ_SUPPORTED_SYM} are supported.")
|
||||
|
||||
# 4 Bits packed into 32 bit datatype.
|
||||
self.pack_factor = 32 // self.weight_bits
|
||||
|
||||
# Tile size used by QQQ kernels.
|
||||
self.tile_size = MARLIN_QQQ_TILE
|
||||
|
||||
# Min out_features dim
|
||||
self.min_n_threads = MARLIN_QQQ_MIN_THREAD_N
|
||||
|
||||
# Min in_features dim
|
||||
self.min_k_threads = MARLIN_QQQ_MIN_THREAD_K
|
||||
|
||||
# Max parallel problems to solve at once (improves large
|
||||
# batch performance)
|
||||
self.max_parallel = MARLIN_QQQ_MAX_PARALLEL
|
||||
|
||||
# Permutation length used by the QQQ kernels.
|
||||
self.perm_len = 1024
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "QQQConfig(weight_bits={}, group_size={})".format(
|
||||
self.weight_bits, self.group_size)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "qqq"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
"""List of filenames to search for in the model directory."""
|
||||
return [
|
||||
"quant_config.json",
|
||||
"quantize_config.json",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "QQQConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["wbits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
return cls(weight_bits, group_size)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QQQLinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return QQQLinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class QQQLinearMethod(LinearMethodBase):
|
||||
"""Linear method for QQQ.
|
||||
|
||||
Args:
|
||||
quant_config: The QQQ quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: QQQConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
weight_loader = extra_weight_attrs["weight_loader"]
|
||||
if params_dtype != torch.float16:
|
||||
raise ValueError(
|
||||
f"The params dtype must be float16, but got {params_dtype}")
|
||||
|
||||
# Validate output_size_per_partition
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if output_size_per_partition % self.quant_config.min_n_threads != 0:
|
||||
raise ValueError(
|
||||
f"Weight output_size_per_partition = "
|
||||
f"{output_size_per_partition} is not divisible by "
|
||||
f"min_n_threads = {self.quant_config.min_n_threads}.")
|
||||
if output_size_per_partition % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
f"Weight output_size_per_partition = "
|
||||
f"{output_size_per_partition} is not divisible by "
|
||||
f"pack_factor = {self.quant_config.pack_factor}.")
|
||||
|
||||
# Validate input_size_per_partition
|
||||
if input_size_per_partition % self.quant_config.min_k_threads != 0:
|
||||
raise ValueError(
|
||||
f"Weight input_size_per_partition = "
|
||||
f"{input_size_per_partition} is not divisible by "
|
||||
f"min_k_threads = {self.quant_config.min_k_threads}.")
|
||||
if (self.quant_config.group_size != -1 and
|
||||
input_size_per_partition % self.quant_config.group_size != 0):
|
||||
raise ValueError(f"Weight input_size_per_partition = "
|
||||
f"{input_size_per_partition} is not divisible by "
|
||||
f"group_size = {self.quant_config.group_size}.")
|
||||
|
||||
# Check that we have at least 4 tiles horizontally in the shard
|
||||
num_tiles_per_perm = self.quant_config.perm_len // (
|
||||
self.quant_config.tile_size**2)
|
||||
if output_size_per_partition % num_tiles_per_perm != 0:
|
||||
raise ValueError(
|
||||
"Each permutation group must reside on the same gpu")
|
||||
|
||||
# Quantized 4Bit weights packed into Int32.
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.tile_size,
|
||||
output_size_per_partition * self.quant_config.tile_size //
|
||||
self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
marlin_tile_size=self.quant_config.tile_size,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
s_channel = ChannelQuantScaleParameter(data=torch.empty(
|
||||
1,
|
||||
output_size_per_partition,
|
||||
device="cuda",
|
||||
dtype=torch.float,
|
||||
),
|
||||
weight_loader=weight_loader,
|
||||
output_dim=1)
|
||||
|
||||
if self.quant_config.group_size == -1:
|
||||
s_group_data = torch.tensor(
|
||||
[],
|
||||
device="cuda",
|
||||
dtype=torch.half,
|
||||
)
|
||||
else:
|
||||
s_group_data = torch.empty(
|
||||
input_size_per_partition // self.quant_config.group_size,
|
||||
output_size_per_partition,
|
||||
device="cuda",
|
||||
dtype=torch.half,
|
||||
)
|
||||
|
||||
s_group_attr = {"data": s_group_data, "weight_loader": weight_loader}
|
||||
|
||||
if self.quant_config.group_size == -1:
|
||||
s_group = BasevLLMParameter(**s_group_attr)
|
||||
else:
|
||||
s_group = GroupQuantScaleParameter(output_dim=1,
|
||||
input_dim=0,
|
||||
**s_group_attr)
|
||||
|
||||
# Allocate workspace (Used for internal locking mechanism)
|
||||
max_workspace_size = (
|
||||
output_size_per_partition //
|
||||
self.quant_config.min_n_threads) * self.quant_config.max_parallel
|
||||
|
||||
workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
|
||||
device="cuda",
|
||||
dtype=torch.int),
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("B", qweight)
|
||||
layer.register_parameter("s_channel", s_channel)
|
||||
layer.register_parameter("s_group", s_group)
|
||||
layer.register_parameter("workspace", workspace)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# required by torch.compile
|
||||
layer.B = Parameter(layer.B.data, requires_grad=False)
|
||||
layer.s_channel = Parameter(layer.s_channel.data, requires_grad=False)
|
||||
layer.s_group = Parameter(layer.s_group.data, requires_grad=False)
|
||||
layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
qweight = layer.B
|
||||
s_ch = layer.s_channel
|
||||
s_group = layer.s_group
|
||||
workspace = layer.workspace
|
||||
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
size_m = x_2d.shape[0]
|
||||
size_k = x_2d.shape[1]
|
||||
size_n = s_ch.shape[1]
|
||||
|
||||
x_int8, s_tok, _ = ops.scaled_int8_quant(x_2d)
|
||||
|
||||
output_2d = ops.marlin_qqq_gemm(x_int8, qweight, s_tok, s_ch, s_group,
|
||||
workspace, size_m, size_n, size_k)
|
||||
|
||||
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output
|
||||
441
model_executor/layers/quantization/quark/quark.py
Normal file
441
model_executor/layers/quantization/quark/quark.py
Normal file
@@ -0,0 +1,441 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import fnmatch
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501
|
||||
QuarkMoEMethod)
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import (
|
||||
QuarkScheme, QuarkW4A4MXFP4, QuarkW8A8Fp8, QuarkW8A8Int8)
|
||||
from vllm.model_executor.layers.quantization.quark.utils import (
|
||||
deep_compare, should_ignore_layer)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
__all__ = ["QuarkLinearMethod"]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class QuarkConfig(QuantizationConfig):
|
||||
|
||||
def __init__(self,
|
||||
quant_config: dict[str, Any],
|
||||
kv_cache_group: Optional[list[str]] = None,
|
||||
kv_cache_config: Optional[dict[str, Any]] = None,
|
||||
pack_method: str = "reorder"):
|
||||
super().__init__()
|
||||
if kv_cache_group is None:
|
||||
kv_cache_group = []
|
||||
self.quant_config = quant_config
|
||||
self.kv_cache_group = kv_cache_group
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.pack_method = pack_method
|
||||
|
||||
def get_linear_method(self) -> "QuarkLinearMethod":
|
||||
return QuarkLinearMethod(self)
|
||||
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "quark"
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
|
||||
# Check if the layer is skipped for quantization.
|
||||
exclude_layers = cast(list[str], self.quant_config.get("exclude"))
|
||||
if should_ignore_layer(prefix,
|
||||
ignore=exclude_layers,
|
||||
fused_mapping=self.packed_modules_mapping):
|
||||
return UnquantizedLinearMethod()
|
||||
if isinstance(layer, LinearBase):
|
||||
scheme = self.get_scheme(layer=layer, layer_name=prefix)
|
||||
layer.scheme = scheme
|
||||
return QuarkLinearMethod(self)
|
||||
if isinstance(layer, Attention):
|
||||
return QuarkKVCacheMethod(self)
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
return QuarkMoEMethod.get_moe_method(self,
|
||||
module=layer,
|
||||
layer_name=prefix)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
|
||||
export_config = config.get("export")
|
||||
if export_config is None:
|
||||
raise ValueError("The export key should be included in "
|
||||
"the configurations of Quark quantized model")
|
||||
kv_cache_group = cast(list[str], export_config.get("kv_cache_group"))
|
||||
pack_method = cast(str, export_config.get("pack_method"))
|
||||
|
||||
# In the export model of quark, the quantization configuration
|
||||
# of kv_cache is stored in layer_quant_config. First, it is
|
||||
# judged whether kv_cache_group exists, and then it is judged
|
||||
# whether layer_quant_config has a quantization configuration
|
||||
# that matches kv_cache.
|
||||
if len(kv_cache_group) == 0:
|
||||
kv_cache_config = None
|
||||
else:
|
||||
kv_cache_set = set(kv_cache_group)
|
||||
layer_quant_config = cast(dict[str, Any],
|
||||
config.get("layer_quant_config"))
|
||||
layer_quant_names = list(layer_quant_config.keys())
|
||||
layer_quant_set = set(layer_quant_names)
|
||||
|
||||
if not kv_cache_set.issubset(layer_quant_set):
|
||||
raise ValueError("The Quark quantized model has the "
|
||||
"kv_cache_group parameter setting, "
|
||||
"but no kv_cache quantization settings "
|
||||
"were found in the quantization "
|
||||
"configuration.")
|
||||
|
||||
q_configs = [
|
||||
cast(dict[str, Any], layer_quant_config.get(name))
|
||||
for name in kv_cache_group
|
||||
]
|
||||
if not all(
|
||||
deep_compare(q_config, q_configs[0])
|
||||
for q_config in q_configs):
|
||||
raise ValueError(
|
||||
"The quantization method used for kv_cache should "
|
||||
"be the same, but the quantization method for the "
|
||||
"kv_cache layer in the config is different.")
|
||||
kv_cache_config = q_configs[0].get("output_tensors")
|
||||
if kv_cache_config is None:
|
||||
raise ValueError(
|
||||
"The kv_cache quantization configuration is empty.")
|
||||
|
||||
# Since we have already set kv_cache quantization configurations,
|
||||
# we will remove the quantization configuration for the
|
||||
# output_tensors corresponding to the kv_cache layer.
|
||||
for q_config in q_configs:
|
||||
q_config["output_tensors"] = None
|
||||
|
||||
# In case q_proj output is also quantized, remove the configuration
|
||||
# to keep qkv consistency.
|
||||
q_proj_q_config = cast(dict[str, Any],
|
||||
layer_quant_config.get("*q_proj"))
|
||||
if q_proj_q_config is not None:
|
||||
q_proj_q_config["output_tensors"] = None
|
||||
|
||||
return cls(quant_config=config,
|
||||
kv_cache_group=kv_cache_group,
|
||||
kv_cache_config=kv_cache_config,
|
||||
pack_method=pack_method)
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
def _check_scheme_supported(self,
|
||||
min_capability: int,
|
||||
error: bool = True) -> bool:
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
|
||||
if capability_tuple is not None:
|
||||
capability = capability_tuple.to_int()
|
||||
supported = capability >= min_capability
|
||||
if error and not supported:
|
||||
raise RuntimeError(
|
||||
"Quantization scheme is not supported for ",
|
||||
f"the current GPU. Min capability: {min_capability}. ",
|
||||
f"Current capability: {capability}.")
|
||||
return supported
|
||||
else:
|
||||
return False
|
||||
|
||||
def _is_fp8_w8a8(self, weight_quant: Optional[dict[str, Any]],
|
||||
input_quant: Optional[dict[str, Any]]) -> bool:
|
||||
# Confirm weights and input quantized.
|
||||
if weight_quant is None or input_quant is None:
|
||||
return False
|
||||
|
||||
# Confirm weight scheme is supported
|
||||
is_fp8_dtype = (weight_quant.get("dtype") == "fp8_e4m3"
|
||||
and input_quant.get("dtype") == "fp8_e4m3")
|
||||
is_static_weight = not weight_quant.get("is_dynamic")
|
||||
is_per_tensor_or_channel_weight = (weight_quant.get("qscheme")
|
||||
in ["per_tensor", "per_channel"])
|
||||
|
||||
if not (is_fp8_dtype and is_static_weight
|
||||
and is_per_tensor_or_channel_weight):
|
||||
return False
|
||||
|
||||
# Dynamic quantization is always supported if weights supported.
|
||||
if input_quant.get("is_dynamic"):
|
||||
return True
|
||||
|
||||
# Confirm activation scheme is supported.
|
||||
is_per_tensor_activation = (input_quant.get("qscheme") == "per_tensor")
|
||||
return is_per_tensor_activation
|
||||
|
||||
def _is_static_tensor_w8a8(self, weight_quant: Optional[dict[str, Any]],
|
||||
input_quant: Optional[dict[str, Any]]) -> bool:
|
||||
# Confirm weights and input quantized.
|
||||
if weight_quant is None or input_quant is None:
|
||||
return False
|
||||
|
||||
is_int8_dtype = (weight_quant.get("dtype") == "int8"
|
||||
and input_quant.get("dtype") == "int8")
|
||||
|
||||
is_tensor = (weight_quant.get("qscheme")
|
||||
in ["per_tensor", "per_channel"]
|
||||
and input_quant.get("qscheme") == "per_tensor")
|
||||
|
||||
is_static = (not weight_quant.get("is_dynamic")
|
||||
and not input_quant.get("is_dynamic"))
|
||||
|
||||
is_weight_symmetric = (weight_quant.get("symmetric") is True)
|
||||
|
||||
# Both symmetric and asymmetric input quantization supported.
|
||||
# Only symmetric weight quantization supported.
|
||||
return is_int8_dtype and is_tensor and is_weight_symmetric and is_static
|
||||
|
||||
def _is_mx_fp4(self, weight_quant: Optional[dict[str, Any]],
|
||||
input_quant: Optional[dict[str, Any]]) -> bool:
|
||||
# Confirm weights and input quantized.
|
||||
if weight_quant is None or input_quant is None:
|
||||
logger.debug("Quark model is not in MX-FP4 format: "
|
||||
"weight_quant or input_quant not set")
|
||||
return False
|
||||
|
||||
# Input and weight dtype needs to be fp4.
|
||||
if weight_quant.get("dtype") != "fp4" or input_quant.get(
|
||||
"dtype") != "fp4":
|
||||
logger.debug("Quark model is not in MX-FP4 format: dtype not fp4")
|
||||
return False
|
||||
|
||||
# Input and weight qscheme needs to be per group.
|
||||
if weight_quant.get("qscheme") != "per_group" or input_quant.get(
|
||||
"qscheme") != "per_group":
|
||||
logger.debug("Quark model is not in MX-FP4 format: not per_group")
|
||||
return False
|
||||
|
||||
# Input and weight group size needs to be 32.
|
||||
if weight_quant.get("group_size") != 32 or input_quant.get(
|
||||
"group_size") != 32:
|
||||
logger.debug(
|
||||
"Quark model is not in MX-FP4 format: not group_size=32")
|
||||
return False
|
||||
|
||||
# Weights need to use static quantization.
|
||||
if weight_quant.get("is_dynamic") is True:
|
||||
logger.debug(
|
||||
"Quark model is not in MX-FP4 format: not weight static")
|
||||
return False
|
||||
|
||||
# Activations need to use dynamic quantization.
|
||||
if input_quant.get("is_dynamic") is False:
|
||||
logger.debug(
|
||||
"Quark model is not in MX-FP4 format: not activation dynamic")
|
||||
return False
|
||||
|
||||
# Activations and weight scales need to be in e8m0 format.
|
||||
if weight_quant.get("scale_format") != "e8m0" or input_quant.get(
|
||||
"scale_format") != "e8m0":
|
||||
logger.debug(
|
||||
"Quark model is not in MX-FP4 format: not scale_format e8m0")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _find_matched_config(self, layer_name: str,
|
||||
module: torch.nn.Module) -> dict[str, Any]:
|
||||
|
||||
proj_name = layer_name.split(".")[-1]
|
||||
if proj_name in self.packed_modules_mapping:
|
||||
shard_proj_names = self.packed_modules_mapping[proj_name]
|
||||
|
||||
# Convert fused_name --> [shard_names]
|
||||
shard_names = [
|
||||
layer_name.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in shard_proj_names
|
||||
]
|
||||
shard_configs = [
|
||||
self._find_matched_config(shard_name, module)
|
||||
for shard_name in shard_names
|
||||
]
|
||||
if not all(
|
||||
deep_compare(q_config, shard_configs[0])
|
||||
for q_config in shard_configs):
|
||||
raise ValueError(
|
||||
f"Found a different quantization configuration for "
|
||||
f"{shard_proj_names} in {layer_name}. vLLM "
|
||||
"requires all to use the same scheme.")
|
||||
return shard_configs[0]
|
||||
else:
|
||||
layer_quant_config = cast(
|
||||
dict[str, Any], self.quant_config.get("layer_quant_config"))
|
||||
for name_pattern in layer_quant_config:
|
||||
if fnmatch.fnmatch(layer_name, name_pattern):
|
||||
return layer_quant_config[name_pattern]
|
||||
|
||||
layer_type = cast(str, type(module))
|
||||
layer_type_quant_config = cast(
|
||||
dict[str, Any],
|
||||
self.quant_config.get("layer_type_quant_config"))
|
||||
if layer_type in layer_type_quant_config:
|
||||
return layer_type_quant_config[layer_type]
|
||||
|
||||
global_quant_config = cast(
|
||||
dict[str, Any], self.quant_config.get("global_quant_config"))
|
||||
return global_quant_config
|
||||
|
||||
def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
|
||||
if config.get("output_tensors") or config.get("bias"):
|
||||
raise NotImplementedError(
|
||||
"Currently, Quark models with output_tensors "
|
||||
"and bias quantized are not supported")
|
||||
weight_config = cast(dict[str, Any], config.get("weight"))
|
||||
input_config = cast(dict[str, Any], config.get("input_tensors"))
|
||||
|
||||
if self._is_fp8_w8a8(weight_config, input_config):
|
||||
is_fp8_w8a8_supported = self._check_scheme_supported(
|
||||
QuarkW8A8Fp8.get_min_capability(), error=False)
|
||||
if is_fp8_w8a8_supported:
|
||||
weight_qscheme = cast(str, weight_config.get("qscheme"))
|
||||
input_static = (input_config is not None and
|
||||
not cast(bool, input_config.get("is_dynamic")))
|
||||
return QuarkW8A8Fp8(qscheme=weight_qscheme,
|
||||
is_static_input_scheme=input_static)
|
||||
elif self._is_static_tensor_w8a8(weight_config, input_config):
|
||||
weight_qscheme = cast(str, weight_config.get("qscheme"))
|
||||
return QuarkW8A8Int8(qscheme=weight_qscheme,
|
||||
is_static_input_scheme=True,
|
||||
input_symmetric=input_config.get("symmetric"))
|
||||
elif self._is_mx_fp4(weight_config, input_config):
|
||||
return QuarkW4A4MXFP4(weight_config, input_config)
|
||||
|
||||
raise NotImplementedError("No quark compatible scheme was found. "
|
||||
f"Weight config: {weight_config}, "
|
||||
f"Input config: {input_config}")
|
||||
|
||||
def get_scheme(self, layer: torch.nn.Module,
|
||||
layer_name: str) -> "QuarkScheme":
|
||||
|
||||
layer_quant_config = self._find_matched_config(layer_name, layer)
|
||||
|
||||
# Find the quant_scheme
|
||||
scheme = self._get_scheme_from_config(layer_quant_config)
|
||||
# Raise error if device does not support the scheme
|
||||
# (e.g. fp8 needs ada lovelace)
|
||||
self._check_scheme_supported(scheme.get_min_capability())
|
||||
|
||||
return scheme
|
||||
|
||||
def get_cache_scale(self, name: str) -> Optional[str]:
|
||||
"""
|
||||
Check whether the param name matches the format for k/v cache scales
|
||||
in quark. If this is the case, return its equivalent param name
|
||||
expected by vLLM
|
||||
|
||||
:param name: param name
|
||||
:return: matching param name for KV cache scale in vLLM
|
||||
"""
|
||||
if name.endswith(".output_scale") and ".k_proj" in name:
|
||||
return name.replace(".k_proj.output_scale", ".attn.k_scale")
|
||||
if name.endswith(".output_scale") and ".v_proj" in name:
|
||||
return name.replace(".v_proj.output_scale", ".attn.v_scale")
|
||||
if name.endswith(".output_scale") and ".q_proj" in name:
|
||||
return name.replace(".q_proj.output_scale", ".attn.q_scale")
|
||||
if name.endswith("self_attn.prob_output_scale"):
|
||||
return name.replace(".prob_output_scale", ".attn.prob_scale")
|
||||
|
||||
# If no matches, return None
|
||||
return None
|
||||
|
||||
|
||||
class QuarkLinearMethod(LinearMethodBase):
|
||||
|
||||
def __init__(self, quantization_config: QuarkConfig):
|
||||
self.quantization_config = quantization_config
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.scheme.process_weights_after_loading(layer)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
"""
|
||||
Use the CompressedTensorsScheme associated with each layer to create
|
||||
the necessary parameters for the layer. See LinearMethodBase for param
|
||||
details
|
||||
"""
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
layer.scheme.create_weights(
|
||||
layer=layer,
|
||||
input_size=input_size,
|
||||
input_size_per_partition=input_size_per_partition,
|
||||
output_partition_sizes=output_partition_sizes,
|
||||
output_size=output_size,
|
||||
params_dtype=params_dtype,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
Use the output of create_weights and the CompressedTensorsScheme
|
||||
associated with the layer to apply the forward pass with the
|
||||
layer input. See LinearMethodBase for param details
|
||||
|
||||
"""
|
||||
scheme = layer.scheme
|
||||
if scheme is None:
|
||||
raise ValueError("A scheme must be defined for each layer")
|
||||
return scheme.apply_weights(layer, x, bias=bias)
|
||||
|
||||
|
||||
class QuarkKVCacheMethod(BaseKVCacheMethod):
|
||||
"""
|
||||
Supports loading kv-cache scaling factors from quark checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: QuarkConfig):
|
||||
self.validate_kv_cache_config(quant_config.kv_cache_config)
|
||||
super().__init__(quant_config)
|
||||
|
||||
@staticmethod
|
||||
def validate_kv_cache_config(kv_cache_config: Optional[dict[str, Any]]):
|
||||
"""
|
||||
Validator for the kv cache configuration. Useful for controlling the
|
||||
kv cache quantization schemes, that are being supported in vLLM
|
||||
:param kv_cache_config: the quark kv cache scheme
|
||||
"""
|
||||
if kv_cache_config is None:
|
||||
return
|
||||
|
||||
dtype = kv_cache_config.get("dtype")
|
||||
if dtype != "fp8_e4m3":
|
||||
raise NotImplementedError(
|
||||
"Currently supported kv cache quantization is "
|
||||
f"dtype=fp8_e4m3, however received {dtype}")
|
||||
|
||||
qscheme = kv_cache_config.get("qscheme")
|
||||
if qscheme != "per_tensor":
|
||||
raise NotImplementedError(
|
||||
"Only support per-tensor scaling factor "
|
||||
"for quark KV cache. "
|
||||
f"Expected qscheme: per_tensor, found qscheme: {qscheme}")
|
||||
237
model_executor/layers/quantization/quark/quark_moe.py
Normal file
237
model_executor/layers/quantization/quark/quark_moe.py
Normal file
@@ -0,0 +1,237 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod"]
|
||||
|
||||
|
||||
class QuarkMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
@staticmethod
|
||||
def get_moe_method(
|
||||
quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821
|
||||
module: torch.nn.Module,
|
||||
layer_name: str) -> "QuarkMoEMethod":
|
||||
layer_quant_config = quant_config._find_matched_config(
|
||||
layer_name, module)
|
||||
|
||||
if (layer_quant_config.get("output_tensors")
|
||||
or layer_quant_config.get("bias")):
|
||||
raise NotImplementedError("Currently, Quark models with "
|
||||
"output_tensors and bias "
|
||||
"quantized are not supported")
|
||||
weight_config = layer_quant_config.get("weight")
|
||||
input_config = layer_quant_config.get("input_tensors")
|
||||
|
||||
if quant_config._is_fp8_w8a8(weight_config, input_config):
|
||||
return QuarkW8A8Fp8MoEMethod(weight_config, input_config)
|
||||
else:
|
||||
raise RuntimeError("Unsupported FusedMoe scheme")
|
||||
|
||||
|
||||
class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
|
||||
def __init__(self, weight_config: dict[str, Any], input_config: dict[str,
|
||||
Any]):
|
||||
self.weight_quant = weight_config
|
||||
self.input_quant = input_config
|
||||
|
||||
weight_qscheme = self.weight_quant.get("qscheme")
|
||||
input_qscheme = self.input_quant.get("qscheme")
|
||||
if not (weight_qscheme == "per_tensor"
|
||||
and input_qscheme == "per_tensor"):
|
||||
raise ValueError(
|
||||
"For FP8 Fused MoE layers, only per-tensor scales "
|
||||
"for weights and activations are supported. Found "
|
||||
f"{weight_qscheme}, {input_qscheme}") # noqa E501
|
||||
|
||||
self.static_input_scales = not self.input_quant.get("is_dynamic")
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
|
||||
params_dtype = torch.float8_e4m3fn
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# WEIGHT_SCALES
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
# They will be combined to a single scale after weight loading.
|
||||
w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
|
||||
2,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
|
||||
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
# Add the quantization method used (per tensor/grouped/channel)
|
||||
# to ensure the weight scales are loaded in properly
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
# INPUT_SCALES
|
||||
if self.static_input_scales:
|
||||
w13_input_scale = torch.nn.Parameter(torch.ones(
|
||||
num_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
set_weight_attrs(w13_input_scale, extra_weight_attrs)
|
||||
|
||||
w2_input_scale = torch.nn.Parameter(torch.ones(
|
||||
num_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
||||
else:
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# Fp8 moe kernels require a single activation scale.
|
||||
# We take the max of all the scales in case they differ.
|
||||
if self.static_input_scales:
|
||||
if (layer.w13_input_scale is None or layer.w2_input_scale is None):
|
||||
raise ValueError(
|
||||
"QuantConfig has static quantization, but found "
|
||||
"activation scales are None.")
|
||||
if (not all_close_1d(layer.w13_input_scale)
|
||||
or not all_close_1d(layer.w2_input_scale)):
|
||||
logger.warning_once(
|
||||
"Found input_scales that are not equal for "
|
||||
"fp8 MoE layer. Using the maximum across experts "
|
||||
"for each layer. ")
|
||||
layer.w13_input_scale = torch.nn.Parameter(
|
||||
layer.w13_input_scale.max(), requires_grad=False)
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
layer.w2_input_scale.max(), requires_grad=False)
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
# Normalize the weights and scales
|
||||
w13_weight, w13_weight_scale, w13_input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
layer.w13_weight, layer.w13_weight_scale,
|
||||
layer.w13_input_scale)
|
||||
w2_weight, w2_weight_scale, w2_input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
layer.w2_weight, layer.w2_weight_scale,
|
||||
layer.w2_input_scale)
|
||||
# Reset the parameter
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight,
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale,
|
||||
requires_grad=False)
|
||||
if w13_input_scale is not None:
|
||||
layer.w13_input_scale = torch.nn.Parameter(w13_input_scale,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight,
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
|
||||
requires_grad=False)
|
||||
if w2_input_scale is not None:
|
||||
layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
|
||||
requires_grad=False)
|
||||
|
||||
# Fp8 moe kernel needs single weight scale for w13 per expert.
|
||||
# We take the max then dequant and requant each expert.
|
||||
assert layer.w13_weight_scale is not None
|
||||
shard_size = layer.intermediate_size_per_partition
|
||||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||
for expert_id in range(layer.local_num_experts):
|
||||
start = 0
|
||||
for shard_id in range(2):
|
||||
dq_weight = per_tensor_dequantize(
|
||||
layer.w13_weight[expert_id][start:start + shard_size, :],
|
||||
layer.w13_weight_scale[expert_id][shard_id])
|
||||
layer.w13_weight[expert_id][
|
||||
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
|
||||
dq_weight, max_w13_scales[expert_id])
|
||||
start += shard_size
|
||||
|
||||
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
||||
requires_grad=False)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=True,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale)
|
||||
@@ -0,0 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .quark_scheme import QuarkScheme
|
||||
from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4
|
||||
from .quark_w8a8_fp8 import QuarkW8A8Fp8
|
||||
from .quark_w8a8_int8 import QuarkW8A8Int8
|
||||
|
||||
__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkW4A4MXFP4"]
|
||||
@@ -0,0 +1,55 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ["QuarkScheme"]
|
||||
|
||||
|
||||
class QuarkScheme(ABC):
|
||||
"""
|
||||
Abstract class used to describe the weight creation and forward pass
|
||||
of different quantization schemes supported by Quark.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
"""
|
||||
Get minimum device capability.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(self, *args, **kwargs):
|
||||
"""
|
||||
Weight creation for the particular scheme. Inputs to this function
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]):
|
||||
"""
|
||||
Run the forward pass for the particular scheme. This is where
|
||||
scheme-specific dequant/quant steps/kernels should be applied.
|
||||
|
||||
:param layer: torch.nn.Module with the registered weights and
|
||||
other parameters relevant to the particular scheme.
|
||||
:param x: input to the layer
|
||||
:param bias: bias parameter
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
"""
|
||||
Called after weight loading is complete for any cleanup that
|
||||
needs to occur.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,126 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4)
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
__all__ = ["QuarkW4A4MXFP4"]
|
||||
|
||||
|
||||
class QuarkW4A4MXFP4(QuarkScheme):
|
||||
|
||||
def __init__(self, weight_quant_spec: dict[str, Any],
|
||||
input_quant_spec: dict[str, Any]):
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.qscheme = "per_group"
|
||||
self.weight_quant_spec = weight_quant_spec
|
||||
self.input_quant_spec = input_quant_spec
|
||||
self.emulate = not current_platform.supports_mx()
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.weight = torch.nn.Parameter(layer.weight.data,
|
||||
requires_grad=False)
|
||||
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
|
||||
requires_grad=False)
|
||||
|
||||
if self.emulate:
|
||||
try:
|
||||
from quark.torch.export.nn.modules import realquantizer
|
||||
from quark.torch.quantization.config.config import (
|
||||
QuantizationSpec)
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"The package `amd-quark` is required to use AMD Quark "
|
||||
"MX-FP4 models. Please install it with `pip install "
|
||||
"amd-quark`.") from err
|
||||
|
||||
weight_quant_spec = QuantizationSpec.from_dict(
|
||||
self.weight_quant_spec)
|
||||
|
||||
weight_quantizer = realquantizer.get_real_quantizer(
|
||||
qspec=weight_quant_spec,
|
||||
quantizer=None,
|
||||
real_quantized=True,
|
||||
reorder=False,
|
||||
float_dtype=self.out_dtype,
|
||||
scale_shape=layer.weight_scale.shape,
|
||||
zero_point_shape=None,
|
||||
)
|
||||
weight_quantizer.scale.data = layer.weight_scale.data
|
||||
|
||||
if not envs.VLLM_QUARK_EMU_MEM_OPT:
|
||||
layer.weight = torch.nn.Parameter(
|
||||
weight_quantizer(layer.weight.data).to(self.out_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
else:
|
||||
self.weight_quantizer = weight_quantizer
|
||||
layer.weight_scale = None
|
||||
|
||||
# This call is necessary to release the scales memory.
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
# WEIGHT
|
||||
weight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // 2,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
packed_dim=1,
|
||||
packed_factor=2,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
weight_scale = GroupQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // OCP_MX_BLOCK_SIZE,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
if self.emulate:
|
||||
if envs.VLLM_QUARK_EMU_MEM_OPT:
|
||||
dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype)
|
||||
else:
|
||||
dq_w = layer.weight
|
||||
qdq_x, _ = per_token_group_quant_mxfp4(x, OCP_MX_BLOCK_SIZE)
|
||||
return F.linear(qdq_x, dq_w, bias)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
@@ -0,0 +1,146 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
__all__ = ["QuarkW8A8Fp8"]
|
||||
|
||||
|
||||
class QuarkW8A8Fp8(QuarkScheme):
|
||||
|
||||
def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
|
||||
self.qscheme = qscheme
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# lovelace and up
|
||||
return 89
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
# If per tensor, when we have a fused module (e.g. QKV) with per
|
||||
# tensor scales (thus N scales being passed to the kernel),
|
||||
# requantize so we can always run per tensor
|
||||
if self.qscheme == "per_tensor":
|
||||
if current_platform.is_rocm():
|
||||
input_scale = getattr(layer, 'input_scale', None)
|
||||
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=input_scale)
|
||||
if input_scale is not None:
|
||||
layer.input_scale = Parameter(input_scale,
|
||||
requires_grad=False)
|
||||
else:
|
||||
max_w_scale = layer.weight_scale
|
||||
weight = layer.weight
|
||||
|
||||
max_w_scale, weight = requantize_with_max_scale(
|
||||
weight=weight,
|
||||
weight_scale=max_w_scale,
|
||||
logical_widths=layer.logical_widths,
|
||||
)
|
||||
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
||||
|
||||
# If channelwise, scales are already lined up, so just transpose.
|
||||
elif self.qscheme == "per_channel":
|
||||
weight = layer.weight
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
input_scale = getattr(layer, 'input_scale', None)
|
||||
weight, weight_scale, input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=input_scale)
|
||||
if input_scale is not None:
|
||||
layer.input_scale = Parameter(input_scale,
|
||||
requires_grad=False)
|
||||
else:
|
||||
weight_scale = layer.weight_scale.data
|
||||
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown quantization scheme {self.qscheme}")
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
layer.input_scale = Parameter(layer.input_scale.max(),
|
||||
requires_grad=False)
|
||||
else:
|
||||
layer.input_scale = None
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=torch.float8_e4m3fn),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
# TODO: update create_xxx_parameter functions to return
|
||||
# the newly added parameters
|
||||
if self.qscheme == "per_channel":
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes)),
|
||||
dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
else:
|
||||
assert self.qscheme == "per_tensor"
|
||||
weight_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
|
||||
# min requirement for fp8 kernels
|
||||
weight_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
input_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
return self.fp8_linear.apply(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias)
|
||||
@@ -0,0 +1,122 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel)
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class QuarkW8A8Int8(QuarkScheme):
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool],
|
||||
input_symmetric: Optional[bool]):
|
||||
self.qscheme = qscheme
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.input_symmetric = input_symmetric
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# turing and up
|
||||
return 75
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
|
||||
is_channelwise=(self.qscheme == "per_channel"),
|
||||
is_static_input_scheme=(self.is_static_input_scheme is True),
|
||||
input_symmetric=(self.input_symmetric is True))
|
||||
|
||||
kernel_type = choose_scaled_mm_linear_kernel(
|
||||
scaled_mm_linear_kernel_config)
|
||||
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=torch.int8),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
if self.qscheme == "per_channel":
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes)),
|
||||
dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
ChannelQuantZPParameter = ChannelQuantScaleParameter
|
||||
weight_zero_point = ChannelQuantZPParameter(
|
||||
data=torch.empty((sum(output_partition_sizes)),
|
||||
dtype=torch.int8),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
else:
|
||||
assert self.qscheme == "per_tensor"
|
||||
weight_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
PerTensorZPParameter = PerTensorScaleParameter
|
||||
weight_zero_point = PerTensorZPParameter(
|
||||
data=torch.empty(len(output_partition_sizes),
|
||||
dtype=torch.int8),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
layer.register_parameter("weight_zero_point", weight_zero_point)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = BasevLLMParameter(data=torch.empty(
|
||||
1, dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
input_zero_point = BasevLLMParameter(data=torch.empty(
|
||||
1, dtype=torch.int8),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
|
||||
self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
|
||||
w_q_param_name="weight",
|
||||
w_s_param_name="weight_scale",
|
||||
i_s_param_name="input_scale",
|
||||
i_zp_param_name="input_zero_point",
|
||||
azp_adj_param_name="azp_adj")
|
||||
|
||||
# Checkpoints are serialized in quark format, which is
|
||||
# different from the format the kernel may want. Handle repacking here.
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.register_parameter("weight_zero_point", None)
|
||||
delattr(layer, 'weight_zero_point')
|
||||
if self.input_symmetric:
|
||||
layer.register_parameter("input_zero_point", None)
|
||||
delattr(layer, 'input_zero_point')
|
||||
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
105
model_executor/layers/quantization/quark/utils.py
Normal file
105
model_executor/layers/quantization/quark/utils.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable, Mapping
|
||||
from types import MappingProxyType
|
||||
from typing import Any, Optional
|
||||
|
||||
import regex as re
|
||||
|
||||
|
||||
def deep_compare(dict1: Any, dict2: Any) -> bool:
|
||||
if type(dict1) is not type(dict2):
|
||||
return False
|
||||
if isinstance(dict1, dict):
|
||||
if dict1.keys() != dict2.keys():
|
||||
return False
|
||||
return all(deep_compare(dict1[k], dict2[k]) for k in dict1)
|
||||
elif isinstance(dict1, list):
|
||||
return set(dict1) == set(dict2)
|
||||
else:
|
||||
return dict1 == dict2
|
||||
|
||||
|
||||
def should_ignore_layer(
|
||||
layer_name: Optional[str],
|
||||
ignore: Iterable[str],
|
||||
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
|
||||
) -> bool:
|
||||
if layer_name is None:
|
||||
return False
|
||||
|
||||
# layer_name = model.layers.0.self_attn.qkv_proj
|
||||
# proj_name = qkv_proj
|
||||
proj_name = layer_name.split(".")[-1]
|
||||
|
||||
# Fused layers like gate_up_proj or qkv_proj will not be fused
|
||||
# in the safetensors checkpoint. So, we convert the name
|
||||
# from the fused version to unfused + check to make sure that
|
||||
# each shard of the fused layer has the same scheme.
|
||||
if proj_name in fused_mapping:
|
||||
shard_proj_names = fused_mapping[proj_name]
|
||||
|
||||
# Convert fused_name --> [shard_names]
|
||||
shard_names = [
|
||||
layer_name.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in shard_proj_names
|
||||
]
|
||||
|
||||
# Layer should be ignored if shards are ignored.
|
||||
should_ignore_layer = None
|
||||
for shard_name in shard_names:
|
||||
should_ignore_shard = check_equal_or_regex_match(
|
||||
layer_name=shard_name, targets=ignore)
|
||||
|
||||
# If shard_idx=0, set layer ignore to match shard.
|
||||
if should_ignore_layer is None:
|
||||
should_ignore_layer = should_ignore_shard
|
||||
|
||||
# If shard_idx=1+ confirm scheme matches prior shards.
|
||||
elif should_ignore_shard != should_ignore_layer:
|
||||
raise ValueError(f"Found a different quantization schemes for "
|
||||
f"{shard_proj_names} in {layer_name}. vLLM "
|
||||
"requires all to use the same scheme.")
|
||||
|
||||
# Unfused layers like down_proj and o_proj will match
|
||||
# the safetensors checkpoint already.
|
||||
else:
|
||||
should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name,
|
||||
targets=ignore)
|
||||
|
||||
assert should_ignore_layer is not None
|
||||
return should_ignore_layer
|
||||
|
||||
|
||||
def check_equal_or_regex_match(layer_name: str,
|
||||
targets: Iterable[str]) -> bool:
|
||||
"""
|
||||
Checks whether a layer_name is exactly equal or a regex match for
|
||||
if target starts with 're:' to any target in list.
|
||||
"""
|
||||
for target in targets:
|
||||
if _is_equal_or_regex_match(layer_name, target):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_equal_or_regex_match(value: str,
|
||||
target: str,
|
||||
check_contains: bool = False) -> bool:
|
||||
"""
|
||||
Checks whether a value is exactly equal or a regex match for target
|
||||
if target starts with 're:'. If check_contains is set to True,
|
||||
additionally checks if the target string is contained within the value.
|
||||
"""
|
||||
|
||||
if target.startswith("re:"):
|
||||
pattern = target[3:]
|
||||
if re.match(pattern, value):
|
||||
return True
|
||||
elif check_contains:
|
||||
if target.lower() in value.lower():
|
||||
return True
|
||||
elif target == value:
|
||||
return True
|
||||
return False
|
||||
86
model_executor/layers/quantization/schema.py
Normal file
86
model_executor/layers/quantization/schema.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file contains the Pydantic schemas for various quantization-related
|
||||
parameters. When a relevant quantization technique is specified, these
|
||||
parameters are loaded in the form of a JSON alongside the model weights
|
||||
and augment the model with additional information needed for use of that
|
||||
technique. The format of this JSON should be specified by one or more
|
||||
schemas contained here.
|
||||
|
||||
For example, when the KV cache is quantized to FP8-E4M3 (currently only
|
||||
possible on ROCm), the model can be optionally augmented with KV cache
|
||||
scaling factors.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
|
||||
|
||||
|
||||
class KVCacheQuantSchema(BaseModel):
|
||||
dtype: str
|
||||
# Each key is a TP rank. Each value is a dictionary mapping a TP rank's
|
||||
# layer indices to their per-tensor KV cache scaling factor.
|
||||
# TODO: Consider pulling this and its validation methods out into its
|
||||
# own schema class (tricky as its members are variable)
|
||||
scaling_factor: dict[int, dict[int, float]]
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_is_fp8(self) -> "KVCacheQuantSchema":
|
||||
assert self.dtype == "float8_e4m3fn", (
|
||||
"Loaded scaling factors intended for KV cache dtype = "
|
||||
f"{self.dtype} rather than float8_e4m3fn!")
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
|
||||
context = info.context
|
||||
if context:
|
||||
tp_size = context["tp_size"]
|
||||
num_hidden_layers = context["num_hidden_layers"]
|
||||
assert len(self.scaling_factor) == tp_size, (
|
||||
f"Loaded dictionary has TP size {len(self.scaling_factor)} "
|
||||
f"but LLM engine is currently running with TP size {tp_size}.")
|
||||
for tp_rank, layer_maps in self.scaling_factor.items():
|
||||
assert len(layer_maps) == num_hidden_layers, (
|
||||
f"KV cache scales map for TP rank {tp_rank} is malformed. "
|
||||
f"Expected {num_hidden_layers} layers, got "
|
||||
f"{len(layer_maps)}.")
|
||||
for i in range(tp_size):
|
||||
assert i in self.scaling_factor, (
|
||||
f"KV cache scales map for TP rank {i} not found.")
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
|
||||
context = info.context
|
||||
if context:
|
||||
tp_rank = context["tp_rank"]
|
||||
num_hidden_layers = context["num_hidden_layers"]
|
||||
layer_scales_map = self.scaling_factor[tp_rank]
|
||||
for i in range(num_hidden_layers):
|
||||
assert i in layer_scales_map, (
|
||||
f"Could not find KV cache scales for layer {i} in "
|
||||
f"TP rank {tp_rank}.")
|
||||
return self
|
||||
|
||||
|
||||
class QuantParamSchema(BaseModel):
|
||||
# TODO: Generalize and extend with more fields
|
||||
# (e.g. weights/activations params) once functionality is enabled
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
model_type: Optional[str]
|
||||
kv_cache: KVCacheQuantSchema
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
|
||||
context = info.context
|
||||
if context:
|
||||
model_type = context.get("model_type", None)
|
||||
if model_type is not None:
|
||||
assert model_type == self.model_type, (
|
||||
f"Model type is {model_type} but loaded "
|
||||
f"scaling factors belonging to different "
|
||||
f"model type {self.model_type}!")
|
||||
return self
|
||||
161
model_executor/layers/quantization/torchao.py
Normal file
161
model_executor/layers/quantization/torchao.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TorchAOConfig(QuantizationConfig):
|
||||
"""Config class for torchao."""
|
||||
|
||||
def __init__(self, torchao_config) -> None:
|
||||
self.torchao_config = torchao_config
|
||||
"""
|
||||
# TorchAO quantization relies on tensor subclasses. In order,
|
||||
# to enable proper caching this needs standalone compile
|
||||
if is_torch_equal_or_newer("2.8.0"):
|
||||
os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"
|
||||
logger.info(
|
||||
"Using TorchAO: Setting VLLM_TEST_STANDALONE_COMPILE=1")
|
||||
|
||||
# TODO: remove after the torch dependency is updated to 2.8
|
||||
if is_torch_equal_or_newer(
|
||||
"2.7.0") and not is_torch_equal_or_newer("2.8.0"):
|
||||
os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
|
||||
logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"TorchAOConfig({self.torchao_config})"
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "torchao"
|
||||
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
return [torch.float32, torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
return ["config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
|
||||
"""Create the quant config from an hf model config"""
|
||||
try:
|
||||
from torchao.core.config import config_from_dict
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"Please install torchao>=0.10.0 via "
|
||||
"`pip install torchao>=0.10.0` to use torchao quantization."
|
||||
) from err
|
||||
|
||||
hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
|
||||
assert hf_config is not None, "quant_type must be specified"
|
||||
assert (len(hf_config) == 1 and "default" in hf_config
|
||||
), "Expected only one key 'default' in quant_type dictionary"
|
||||
quant_type = hf_config["default"]
|
||||
ao_config = config_from_dict(quant_type)
|
||||
return cls(ao_config)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if not isinstance(layer, LinearBase):
|
||||
return None
|
||||
|
||||
from torchao.quantization import ModuleFqnToConfig
|
||||
|
||||
module_fqn = prefix
|
||||
if isinstance(self.torchao_config, ModuleFqnToConfig):
|
||||
module_fqn_to_config = self.torchao_config.module_fqn_to_config
|
||||
c = module_fqn_to_config.get(
|
||||
module_fqn) or module_fqn_to_config.get("_default", None)
|
||||
if c is not None:
|
||||
current_torchao_config = TorchAOConfig(c)
|
||||
return TorchAOLinearMethod(current_torchao_config)
|
||||
else:
|
||||
return UnquantizedLinearMethod()
|
||||
|
||||
return TorchAOLinearMethod(self)
|
||||
|
||||
def get_scaled_act_names(self) -> list[str]:
|
||||
return []
|
||||
|
||||
|
||||
def torchao_quantize_param_data(param: torch.Tensor,
|
||||
torchao_config: Any) -> torch.nn.Parameter:
|
||||
"""Quantize a Tensor with torchao quantization specified by torchao_config
|
||||
|
||||
Args:
|
||||
`param`: weight parameter of the linear module
|
||||
`torchao_config`: type of quantization and their arguments we want to
|
||||
use to quantize the Tensor
|
||||
"""
|
||||
from torchao.core.config import AOBaseConfig
|
||||
from torchao.quantization import quantize_
|
||||
assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
|
||||
dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
|
||||
dummy_linear.weight = param
|
||||
quantize_(dummy_linear, torchao_config)
|
||||
return dummy_linear.weight
|
||||
|
||||
|
||||
class TorchAOLinearMethod(LinearMethodBase):
|
||||
"""Linear method for torchao.
|
||||
|
||||
Args:
|
||||
torchao_config: The torchao quantization config, a string
|
||||
that encodes the type of quantization and all relevant arguments.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: TorchAOConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
weight = Parameter(
|
||||
torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
weight = torchao_quantize_param_data(weight,
|
||||
self.quant_config.torchao_config)
|
||||
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, extra_weight_attrs)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return F.linear(x, layer.weight, bias)
|
||||
121
model_executor/layers/quantization/tpu_int8.py
Normal file
121
model_executor/layers/quantization/tpu_int8.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.parameter import ModelWeightParameter
|
||||
|
||||
ACTIVATION_SCHEMES = ["none"]
|
||||
|
||||
|
||||
class Int8TpuConfig(QuantizationConfig):
|
||||
"""Int8 Quantization Config class for TPU Backend."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
activation_scheme: str = "none",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if activation_scheme not in ACTIVATION_SCHEMES:
|
||||
raise ValueError(
|
||||
f"Unsupported activation scheme {activation_scheme}")
|
||||
self.activation_scheme = activation_scheme
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "tpu_int8"
|
||||
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
raise NotImplementedError(
|
||||
"This function should not be called with TPU Backend")
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "Int8TpuConfig":
|
||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
||||
return cls(activation_scheme=activation_scheme)
|
||||
|
||||
def get_quant_method(self, layer: Module,
|
||||
prefix: str) -> Optional["TPUInt8LinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return TPUInt8LinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class TPUInt8LinearMethod(LinearMethodBase):
|
||||
"""Int8 Linear method for TPU Quant. """
|
||||
|
||||
def __init__(self, quant_config: Int8TpuConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: Module, input_size_per_partition: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
def _quantize_weight(
|
||||
self, weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
weight_dtype = weight.dtype
|
||||
weight = weight.cpu().to(torch.float32)
|
||||
n_bit = 8
|
||||
eps = 1e-5
|
||||
max_int = 2**(n_bit - 1) - 1
|
||||
min_int = -(2**(n_bit - 1))
|
||||
max_val = weight.abs().amax(dim=-1, keepdim=True)
|
||||
max_val = max_val.clamp(min=eps)
|
||||
qscale = max_val / max_int
|
||||
qweight = torch.clamp(torch.round(weight * (1.0 / qscale)), min_int,
|
||||
max_int).to(torch.int8)
|
||||
qscale = qscale.squeeze().to(weight_dtype)
|
||||
return qweight, qscale
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
||||
device = layer.weight.device
|
||||
qweight, qscale = self._quantize_weight(layer.weight)
|
||||
qweight = qweight.to(device)
|
||||
qscale = qscale.to(device)
|
||||
layer.weight = Parameter(qweight, requires_grad=False)
|
||||
layer.scale = Parameter(qscale, requires_grad=False)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
try:
|
||||
import torch_xla.experimental.xla_quantized_matmul # noqa: F401
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"Please install torch_xla by following the instructions at "
|
||||
"https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html " # noqa: E501
|
||||
"to run vLLM on TPU.") from err
|
||||
weight = layer.weight
|
||||
scale = layer.scale
|
||||
out = torch.ops.xla.quantized_matmul(x, weight, scale)
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out
|
||||
6
model_executor/layers/quantization/utils/__init__.py
Normal file
6
model_executor/layers/quantization/utils/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .layer_utils import replace_parameter, update_tensor_inplace
|
||||
|
||||
__all__ = ['update_tensor_inplace', 'replace_parameter']
|
||||
52
model_executor/layers/quantization/utils/allspark_utils.py
Normal file
52
model_executor/layers/quantization/utils/allspark_utils.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD = 1024
|
||||
ALLSPARK_SUPPORTED_QUANT_TYPES = [scalar_types.uint8b128]
|
||||
ALLSPARK_AMPERE_N_ALIGN = 16
|
||||
ALLSPARK_AMPERE_K_ALIGN = 16
|
||||
|
||||
|
||||
def check_allspark_supported_dtype_shape(input_size_per_partition: int,
|
||||
output_size_per_partition: int,
|
||||
group_size: int,
|
||||
weight_dtype: ScalarType,
|
||||
act_dtype: torch.dtype):
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
device_capability = (-1 if capability_tuple is None else
|
||||
capability_tuple.to_int())
|
||||
|
||||
# For Ampere GPU
|
||||
if device_capability >= 80 and device_capability < 90:
|
||||
if group_size != -1:
|
||||
return False, \
|
||||
"For Ampere GPU, AllSpark does not support group_size "\
|
||||
f"= {group_size}. Only group_size = -1 are supported."
|
||||
|
||||
if weight_dtype not in ALLSPARK_SUPPORTED_QUANT_TYPES:
|
||||
return False, "For Ampere GPU, AllSpark does not support "\
|
||||
f"quant type ({weight_dtype}). Only quant type "\
|
||||
f"({ALLSPARK_SUPPORTED_QUANT_TYPES}) are supported."
|
||||
|
||||
if input_size_per_partition % ALLSPARK_AMPERE_K_ALIGN != 0 \
|
||||
or output_size_per_partition % ALLSPARK_AMPERE_N_ALIGN != 0:
|
||||
return False, \
|
||||
"AllSpark needs input_size_per_partition % "\
|
||||
f"{ALLSPARK_AMPERE_K_ALIGN} = 0 and "\
|
||||
f"output_size_per_partition % {ALLSPARK_AMPERE_N_ALIGN} = 0 "\
|
||||
"for Ampere GPU optimized kernels."
|
||||
|
||||
if act_dtype != torch.float16 and act_dtype != torch.bfloat16:
|
||||
return False, \
|
||||
"AllSpark only supports act_dtype = float16 or bfloat16,"\
|
||||
f"for Ampere GPU, but got act_dtype = {act_dtype}."
|
||||
else:
|
||||
return False, "AllSpark currently does not support "\
|
||||
f"device_capability = {device_capability}."
|
||||
|
||||
return True, None
|
||||
208
model_executor/layers/quantization/utils/bitblas_utils.py
Normal file
208
model_executor/layers/quantization/utils/bitblas_utils.py
Normal file
@@ -0,0 +1,208 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
MINIMUM_BITBLAS_VERSION = "0.1.0"
|
||||
|
||||
BITBLAS_MIN_WEIGHT_SIZE_N = 16
|
||||
BITBLAS_MIN_WEIGHT_SIZE_K = 16
|
||||
GPTQ_BITBLAS_MAX_PARALLEL = 16
|
||||
|
||||
BITBLAS_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
|
||||
# For dynamic shape code generation
|
||||
BITBLAS_OPTIMIZE_FEATURES = [1, 16, 32, 64, 128, 256, 512, 1024]
|
||||
# If want to enable high performance for contiguous batching
|
||||
# Please use the following values
|
||||
BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS = [16, 32, 64, 128, 256, 512, 1024]
|
||||
|
||||
BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8]
|
||||
BITBLAS_SUPPORTED_SYM = [False, True]
|
||||
|
||||
|
||||
# Determines the supported quantization types for BitBLAS based on the
|
||||
# device's capability and whether zero-point (zp) is used.
|
||||
def query_bitblas_supported_quant_types(has_zp: bool,
|
||||
device_capability: Optional[int] = None
|
||||
):
|
||||
if device_capability is None:
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
device_capability = (-1 if capability_tuple is None else
|
||||
capability_tuple.to_int())
|
||||
|
||||
if device_capability < 70:
|
||||
return []
|
||||
|
||||
if has_zp:
|
||||
# AWQ style, unsigned + runtime zero-point
|
||||
return [scalar_types.uint4, scalar_types.uint8]
|
||||
else:
|
||||
# GPTQ style, unsigned + symmetric bias
|
||||
# TODO: once fp8_bitblas is merged into "gptq_bitblas" we should be able
|
||||
# to add `scalar_types.float8_e4m3fn` here
|
||||
return [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
|
||||
|
||||
def _check_bitblas_supported(
|
||||
quant_type: ScalarType,
|
||||
group_size: Optional[int],
|
||||
has_zp: bool,
|
||||
device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]:
|
||||
|
||||
if device_capability is None:
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
device_capability = (-1 if capability_tuple is None else
|
||||
capability_tuple.to_int())
|
||||
|
||||
supported_types = query_bitblas_supported_quant_types(
|
||||
has_zp, device_capability)
|
||||
|
||||
if quant_type not in supported_types:
|
||||
return (False, f"BitBLAS does not support weight_bits = {quant_type}. "
|
||||
f"Only types = {supported_types} "
|
||||
f"are supported (for group_size = {group_size}, "
|
||||
f"device_capability = {device_capability}, zp = {has_zp}).")
|
||||
if (group_size is None or group_size not in BITBLAS_SUPPORTED_GROUP_SIZES):
|
||||
return (False, f"BitBLAS does not support group_size = {group_size}. "
|
||||
f"Only group_sizes = {BITBLAS_SUPPORTED_GROUP_SIZES} "
|
||||
"are supported.")
|
||||
|
||||
# Finally, check if bitblas is installed
|
||||
try:
|
||||
import bitblas
|
||||
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
|
||||
raise ImportError("bitblas version is wrong. Please "
|
||||
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
|
||||
except ImportError:
|
||||
return False, "BitBLAS is not installed."
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
def check_bitblas_supported(quant_type: ScalarType,
|
||||
group_size: int,
|
||||
has_zp: bool = False,
|
||||
device_capability: Optional[int] = None) -> bool:
|
||||
cond, _ = _check_bitblas_supported(quant_type, group_size, has_zp,
|
||||
device_capability)
|
||||
return cond
|
||||
|
||||
|
||||
def verify_bitblas_supported(quant_type: ScalarType,
|
||||
group_size: int,
|
||||
has_zp: bool = False) -> None:
|
||||
cond, err_msg = _check_bitblas_supported(quant_type, group_size, has_zp)
|
||||
if not cond:
|
||||
assert err_msg is not None
|
||||
raise ValueError(err_msg)
|
||||
|
||||
|
||||
def verify_bitblas_supports_shape(output_size_per_partition: int,
|
||||
input_size_per_partition: int,
|
||||
input_size: int, group_size: int) -> None:
|
||||
|
||||
# Validate output_size_per_partition
|
||||
if output_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_N != 0:
|
||||
raise ValueError(f"Weight output_size_per_partition = "
|
||||
f"{output_size_per_partition} is not divisible by "
|
||||
f" min_thread_n = {BITBLAS_MIN_WEIGHT_SIZE_N}. "
|
||||
"Consider reducing tensor_parallel_size or running "
|
||||
"with --quantization gptq.")
|
||||
|
||||
# Validate input_size_per_partition
|
||||
if input_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_K != 0:
|
||||
raise ValueError(f"Weight input_size_per_partition = "
|
||||
f"{input_size_per_partition} is not divisible "
|
||||
f"by min_thread_k = {BITBLAS_MIN_WEIGHT_SIZE_K}. "
|
||||
"Consider reducing tensor_parallel_size or running "
|
||||
"with --quantization gptq.")
|
||||
|
||||
if (group_size < input_size
|
||||
and input_size_per_partition % group_size != 0):
|
||||
raise ValueError(
|
||||
f"Weight input_size_per_partition = {input_size_per_partition}"
|
||||
f" is not divisible by group_size = {group_size}."
|
||||
"Consider reducing tensor_parallel_size or running "
|
||||
"with --quantization gptq.")
|
||||
|
||||
|
||||
def check_bitblas_supports_shape(output_size_per_partition: int,
|
||||
input_size_per_partition: int,
|
||||
input_size: int, group_size: int) \
|
||||
-> tuple[bool, Optional[str]]:
|
||||
try:
|
||||
verify_bitblas_supports_shape(output_size_per_partition,
|
||||
input_size_per_partition, input_size,
|
||||
group_size)
|
||||
except ValueError as e:
|
||||
return False, e.__str__()
|
||||
return True, None
|
||||
|
||||
|
||||
def bitblas_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
|
||||
return (not act_order) or (act_order and not is_row_parallel)
|
||||
|
||||
|
||||
def bitblas_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
|
||||
is_row_parallel: bool) -> bool:
|
||||
# Need to repeat scales on every rank if act_ordering or
|
||||
# channelwise and RowParallelLinear
|
||||
is_channelwise = group_size == -1
|
||||
return act_order or (is_channelwise and is_row_parallel)
|
||||
|
||||
|
||||
def bitblas_make_empty_g_idx(device: torch.device) -> torch.Tensor:
|
||||
return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
|
||||
requires_grad=False)
|
||||
|
||||
|
||||
def bitblas_make_empty_zp(device: torch.device) -> torch.Tensor:
|
||||
return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
|
||||
requires_grad=False)
|
||||
|
||||
|
||||
def bitblas_sort_g_idx(
|
||||
g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
|
||||
return g_idx[g_idx_sort_indices], g_idx_sort_indices
|
||||
|
||||
|
||||
def unpack_gptq_qzeros(qzeros, bits, is_gptq_v2=False) -> torch.Tensor:
|
||||
qzeros = qzeros.view(torch.int32)
|
||||
elems_per_int32 = 32 // bits
|
||||
unpacked_zeros = torch.zeros(
|
||||
(qzeros.shape[0], qzeros.shape[1] * elems_per_int32),
|
||||
dtype=torch.int8,
|
||||
device=qzeros.device,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
for col in range(unpacked_zeros.shape[1]):
|
||||
i = col % elems_per_int32
|
||||
unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >>
|
||||
(bits * i)) & 0xF
|
||||
if not is_gptq_v2:
|
||||
return unpacked_zeros + 1
|
||||
return unpacked_zeros
|
||||
|
||||
|
||||
def unpack_gptq_qweight(qweight, bits):
|
||||
qweight = qweight.view(torch.int8)
|
||||
elems_per_int8 = 8 // bits
|
||||
unpacked_weight = torch.zeros(
|
||||
(qweight.shape[0], qweight.shape[1] * elems_per_int8),
|
||||
dtype=torch.int8,
|
||||
device=qweight.device,
|
||||
requires_grad=False,
|
||||
)
|
||||
for col in range(unpacked_weight.shape[1]):
|
||||
i = col % elems_per_int8
|
||||
unpacked_weight[:, col] = (qweight[:, col // elems_per_int8] >>
|
||||
(bits * i))
|
||||
|
||||
return torch.bitwise_and(unpacked_weight, 2**bits - 1)
|
||||
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user