first commit

This commit is contained in:
2026-03-10 13:31:25 +08:00
parent ba974cecfa
commit b62b889355
2604 changed files with 438977 additions and 0 deletions

View File

@@ -0,0 +1,157 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Literal, get_args
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
QuantizationMethods = Literal[
"awq",
"deepspeedfp",
"tpu_int8",
"fp8",
"ptpc_fp8",
"fbgemm_fp8",
"modelopt",
"modelopt_fp4",
"bitblas",
"gguf",
"gptq_marlin_24",
"gptq_marlin",
"gptq_bitblas",
"awq_marlin",
"gptq",
"compressed-tensors",
"bitsandbytes",
"hqq",
"experts_int8",
"ipex",
"quark",
"moe_wna16",
"torchao",
"auto-round",
"rtn",
"inc",
"mxfp4",
"petit_nvfp4",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
# The customized quantization methods which will be added to this dict.
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {}
def register_quantization_config(quantization: str):
"""Register a customized vllm quantization config.
When a quantization method is not supported by vllm, you can register a customized
quantization config to support it.
Args:
quantization (str): The quantization method name.
Examples:
>>> from vllm.model_executor.layers.quantization import register_quantization_config
>>> from vllm.model_executor.layers.quantization import get_quantization_config
>>> from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
>>>
>>> @register_quantization_config("my_quant")
... class MyQuantConfig(QuantizationConfig):
... pass
>>>
>>> get_quantization_config("my_quant")
<class 'MyQuantConfig'>
""" # noqa: E501
def _wrapper(quant_config_cls):
if quantization in QUANTIZATION_METHODS:
raise ValueError(
f"The quantization method `{quantization}` is already exists.")
if not issubclass(quant_config_cls, QuantizationConfig):
raise ValueError("The quantization config must be a subclass of "
"`QuantizationConfig`.")
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG[quantization] = quant_config_cls
QUANTIZATION_METHODS.append(quantization)
return quant_config_cls
return _wrapper
def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
if quantization not in QUANTIZATION_METHODS:
raise ValueError(f"Invalid quantization method: {quantization}")
# lazy import to avoid triggering `torch.compile` too early
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
from .auto_round import AutoRoundConfig
from .awq import AWQConfig
from .awq_marlin import AWQMarlinConfig
from .bitblas import BitBLASConfig
from .bitsandbytes import BitsAndBytesConfig
from .compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsConfig)
from .deepspeedfp import DeepSpeedFPConfig
from .experts_int8 import ExpertsInt8Config
from .fbgemm_fp8 import FBGEMMFp8Config
from .fp8 import Fp8Config
from .gguf import GGUFConfig
from .gptq import GPTQConfig
from .gptq_bitblas import GPTQBitBLASConfig
from .gptq_marlin import GPTQMarlinConfig
from .gptq_marlin_24 import GPTQMarlin24Config
from .hqq_marlin import HQQMarlinConfig
from .inc import INCConfig
from .ipex_quant import IPEXConfig
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config
from .petit import PetitNvFp4Config
from .ptpc_fp8 import PTPCFp8Config
from .rtn import RTNConfig
from .torchao import TorchAOConfig
from .tpu_int8 import Int8TpuConfig
method_to_config: dict[str, type[QuantizationConfig]] = {
"awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig,
"tpu_int8": Int8TpuConfig,
"fp8": Fp8Config,
"fbgemm_fp8": FBGEMMFp8Config,
"modelopt": ModelOptFp8Config,
"modelopt_fp4": ModelOptNvFp4Config,
"bitblas": BitBLASConfig,
"gguf": GGUFConfig,
"gptq_marlin_24": GPTQMarlin24Config,
"gptq_marlin": GPTQMarlinConfig,
"gptq_bitblas": GPTQBitBLASConfig,
"awq_marlin": AWQMarlinConfig,
"gptq": GPTQConfig,
"compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig,
"ptpc_fp8": PTPCFp8Config,
"hqq": HQQMarlinConfig,
"experts_int8": ExpertsInt8Config,
"ipex": IPEXConfig,
"quark": QuarkConfig,
"moe_wna16": MoeWNA16Config,
"torchao": TorchAOConfig,
"auto-round": AutoRoundConfig,
"rtn": RTNConfig,
"inc": INCConfig,
"mxfp4": Mxfp4Config,
"petit_nvfp4": PetitNvFp4Config,
}
# Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
return method_to_config[quantization]
__all__ = [
"QuantizationConfig",
"QuantizationMethods",
"get_quantization_config",
"QUANTIZATION_METHODS",
]

View File

@@ -0,0 +1,388 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fractions import Fraction
from typing import TYPE_CHECKING, Any, Optional, Union
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
logger = init_logger(__name__)
class AutoRoundConfig(QuantizationConfig):
"""Config class for AutoRound.
Reference: https://arxiv.org/pdf/2309.05516
"""
SUPPORTED_BITS = {2, 3, 4, 8}
SUPPORTED_DTYPES = {"int"}
SUPPORTED_FORMATS = {"auto_round:auto_gptq", "auto_round:auto_awq"}
SUPPORTED_BACKENDS = {
"auto",
"gptq",
"gptq:marlin",
"awq",
"awq:marlin",
"marlin",
"ipex",
}
def __init__(
self,
weight_bits: int,
group_size: int,
sym: bool = True,
packing_format: str = "auto_round:auto_gptq",
block_name_to_quantize: Optional[Union[str, list[str]]] = None,
extra_config: Optional[dict[str, Any]] = None,
data_type: str = "int",
backend: str = "auto",
) -> None:
super().__init__()
if weight_bits not in self.SUPPORTED_BITS:
raise ValueError(f"Unsupported weight_bits: {weight_bits}, "
f"currently only support {self.SUPPORTED_BITS}")
if data_type not in self.SUPPORTED_DTYPES:
raise ValueError(
f"Unsupported data_type: {data_type},"
f" currently only support {self.SUPPORTED_DTYPES}")
if packing_format not in self.SUPPORTED_FORMATS:
raise ValueError(
f"Unsupported packing_format: {packing_format}, "
f"currently only support {self.SUPPORTED_FORMATS}")
if backend not in self.SUPPORTED_BACKENDS:
raise ValueError(
f"Unsupported backend: {backend}, "
f"currently only support {self.SUPPORTED_BACKENDS}")
self.weight_bits = weight_bits
self.group_size = group_size
self.sym = sym
self.packing_format = packing_format
self.block_name_to_quantize = (block_name_to_quantize.split(",") if
isinstance(block_name_to_quantize, str)
else block_name_to_quantize)
self.extra_config = extra_config
self.data_type = data_type
self.backend = backend
self.pack_factor = Fraction(32, weight_bits)
def __repr__(self) -> str:
return (f"AutoRoundConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, sym={self.sym})")
@classmethod
def get_name(cls) -> QuantizationMethods:
return "auto-round"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 60
@classmethod
def get_config_filenames(cls) -> list[str]:
return ["quantization_config.json"]
@classmethod
def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig":
return cls(
weight_bits=cls.get_from_keys(config, ["bits"]),
group_size=cls.get_from_keys(config, ["group_size"]),
sym=cls.get_from_keys(config, ["sym"]),
packing_format=cls.get_from_keys_or(config, ["packing_format"],
"auto_round:auto_gptq"),
block_name_to_quantize=cls.get_from_keys_or(
config, ["block_name_to_quantize", "to_quant_block_names"],
None),
extra_config=cls.get_from_keys_or(config, ["extra_config"], None),
data_type=cls.get_from_keys_or(config, ["data_type"], "int"),
backend=cls.get_from_keys_or(config, ["backend", "vllm_backend"],
"auto"),
)
def get_layer_config(self, layer, layer_name: str):
def get_config(name: str, quantized: bool = True):
cfg = self.extra_config.get(name, {}) if self.extra_config else {}
return (
cfg.get("bits", self.weight_bits if quantized else 16),
cfg.get("group_size", self.group_size if quantized else -1),
cfg.get("sym", self.sym if quantized else True),
)
# 1. Exact match from config
if self.extra_config and layer_name in self.extra_config:
return get_config(layer_name)
# 2. Determine whether layer should be quantized
quantized = not isinstance(layer, ParallelLMHead)
if self.block_name_to_quantize:
quantized = any(
layer_name.startswith(name)
for name in self.block_name_to_quantize)
# 3. Handle fused MoE
if self.extra_config and "fusedmoe" in layer.__class__.__name__.lower(
):
moe_configs = [
get_config(name, quantized) for name in self.extra_config
if name.startswith(layer_name)
]
if moe_configs:
if len(set(moe_configs)) == 1:
return moe_configs[0]
raise ValueError(f"Fused MoE layer '{layer_name}' requires "
f"consistent quant config for all sub-layers")
# 4. Handle fused QKV or other patterns
if self.extra_config:
for fusion_key, sub_keys in self.packed_modules_mapping.items():
if fusion_key in layer_name and layer_name.count(
fusion_key) == 1:
sub_names = [
layer_name.replace(fusion_key, sub_key)
for sub_key in sub_keys
]
sub_configs = [
get_config(name, quantized) for name in sub_names
]
if len(set(sub_configs)) == 1:
return sub_configs[0]
raise ValueError(
f"Fused module '{layer_name}' requires "
f"consistent quant config for {sub_names}")
# 5. Fallback
return get_config(layer_name, quantized)
def check_quantized(self, weight_bits: int) -> bool:
return weight_bits < 16
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
if self.block_name_to_quantize is not None:
self.block_name_to_quantize = hf_to_vllm_mapper.apply_list(
self.block_name_to_quantize)
if self.extra_config is not None:
self.extra_config = hf_to_vllm_mapper.apply_dict(self.extra_config)
def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported, check_moe_marlin_supports_layer)
weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
if not self.check_quantized(weight_bits):
if isinstance(layer, (LinearBase, ParallelLMHead)):
return UnquantizedLinearMethod()
else:
return None
logger.debug(
"[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
prefix,
layer.__class__.__name__,
weight_bits,
group_size,
sym,
)
if backend == "auto" or "marlin" in backend:
AWQ_TYPE_MAP = {
4: scalar_types.uint4,
8: scalar_types.uint8,
}
use_marlin = (weight_bits
in AWQ_TYPE_MAP) and check_marlin_supported(
AWQ_TYPE_MAP[weight_bits], group_size, not sym)
if isinstance(layer, FusedMoE):
use_marlin = use_marlin and check_moe_marlin_supports_layer(
layer, group_size)
else:
use_marlin = False
if use_marlin:
from vllm.model_executor.layers.quantization.awq_marlin import (
AWQMarlinConfig, AWQMarlinLinearMethod, AWQMoEMethod)
quant_args_marlin = AWQMarlinConfig(
weight_bits=weight_bits,
group_size=group_size,
zero_point=not sym,
lm_head_quantized=False,
full_config={},
modules_to_not_convert=[],
)
else:
from vllm.model_executor.layers.quantization.awq import (
AWQConfig, AWQLinearMethod)
quant_args = AWQConfig(
weight_bits=weight_bits,
group_size=group_size,
zero_point=not sym,
)
if isinstance(layer, FusedMoE):
if use_marlin:
return AWQMoEMethod(quant_args_marlin, layer.moe_config)
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config)
config = {
"quant_method": "awq",
"bits": weight_bits,
"group_size": group_size,
"zero_point": not sym,
"lm_head": False,
}
return MoeWNA16Config.from_config(config).get_quant_method(
layer, prefix)
if isinstance(layer, (LinearBase, ParallelLMHead)):
if use_marlin:
return AWQMarlinLinearMethod(quant_args_marlin)
else:
return AWQLinearMethod(quant_args)
return None
def apply_gptq_quant_layer(self,
layer,
prefix: str,
backend: str = "auto"):
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported, check_moe_marlin_supports_layer)
weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
if not self.check_quantized(weight_bits):
if isinstance(layer, (LinearBase, ParallelLMHead)):
return UnquantizedLinearMethod()
else:
return None
logger.debug(
"[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
prefix,
layer.__class__.__name__,
weight_bits,
group_size,
sym,
)
if backend == "auto" or "marlin" in backend:
GPTQ_TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
}
use_marlin = (weight_bits,
sym) in GPTQ_TYPE_MAP and check_marlin_supported(
GPTQ_TYPE_MAP[(weight_bits, sym)],
group_size,
has_zp=not sym)
if isinstance(layer, FusedMoE):
use_marlin = use_marlin and check_moe_marlin_supports_layer(
layer, group_size)
else:
use_marlin = False
if use_marlin:
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig, GPTQMarlinLinearMethod, GPTQMarlinMoEMethod)
quant_args_marlin = GPTQMarlinConfig(
weight_bits=weight_bits,
group_size=group_size,
is_sym=sym,
lm_head_quantized=False,
desc_act=False,
dynamic={},
full_config={},
)
else:
from vllm.model_executor.layers.quantization.gptq import (
GPTQConfig, GPTQLinearMethod)
quant_args = GPTQConfig(
weight_bits=weight_bits,
group_size=group_size,
lm_head_quantized=False,
desc_act=False,
dynamic={},
)
if isinstance(layer, FusedMoE):
if use_marlin:
return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe_config)
else:
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config)
config = {
"quant_method": "gptq",
"bits": weight_bits,
"group_size": group_size,
"sym": sym,
"lm_head": False,
}
return MoeWNA16Config.from_config(config).get_quant_method(
layer, prefix)
if isinstance(layer, (LinearBase, ParallelLMHead)):
if use_marlin:
return GPTQMarlinLinearMethod(quant_args_marlin)
else:
return GPTQLinearMethod(quant_args)
return None
def apply_ipex_quant_layer(self, layer, prefix: str):
weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
if not self.check_quantized(weight_bits):
if isinstance(layer, (LinearBase, ParallelLMHead)):
return UnquantizedLinearMethod()
else:
return None
from vllm.model_executor.layers.quantization.ipex_quant import (
IPEXAWQLinearMethod, IPEXConfig, IPEXGPTQLinearMethod)
if isinstance(layer, (LinearBase, ParallelLMHead)):
if "awq" in self.packing_format:
config = IPEXConfig(method="awq",
weight_bits=weight_bits,
group_size=group_size)
return IPEXAWQLinearMethod(config)
elif "gptq" in self.packing_format:
config = IPEXConfig(method="gptq",
weight_bits=weight_bits,
group_size=group_size)
return IPEXGPTQLinearMethod(config)
else:
raise ValueError(
f"ipex backend only supports awq "
f"and gtpq format,but got {self.packing_format}")
else:
return None
def get_quant_method(self, layer: torch.nn.Module, prefix: str):
if (current_platform.is_cpu() or current_platform.is_xpu()
or self.backend == "ipex"):
return self.apply_ipex_quant_layer(layer, prefix)
if "gptq" in self.packing_format or "gptq" in self.backend:
return self.apply_gptq_quant_layer(layer, prefix)
if "awq" in self.packing_format or "awq" in self.backend:
return self.apply_awq_quant_layer(layer, prefix)

View File

@@ -0,0 +1,228 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional, Union
import torch
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
logger = init_logger(__name__)
class AWQConfig(QuantizationConfig):
"""Config class for AWQ.
Reference: https://arxiv.org/abs/2306.00978
"""
def __init__(
self,
weight_bits: int,
group_size: int,
zero_point: bool,
modules_to_not_convert: Optional[list[str]] = None,
) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point
self.modules_to_not_convert = modules_to_not_convert or []
if self.weight_bits != 4:
raise ValueError(
"Currently, only 4-bit weight quantization is supported for "
f"AWQ, but got {self.weight_bits} bits.")
self.pack_factor = 32 // self.weight_bits
def __repr__(self) -> str:
return (f"AWQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point}, "
f"modules_to_not_convert={self.modules_to_not_convert})")
def get_name(self) -> QuantizationMethods:
return "awq"
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.half]
@classmethod
def get_min_capability(cls) -> int:
# The AWQ kernel only supports Turing or newer GPUs.
return 75
@staticmethod
def get_config_filenames() -> list[str]:
return [
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
# E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
"quantize_config.json",
]
@classmethod
def from_config(cls, config: dict[str, Any]) -> "AWQConfig":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None)
return cls(weight_bits, group_size, zero_point, modules_to_not_convert)
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
if isinstance(layer, LinearBase):
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
return AWQLinearMethod(self)
elif isinstance(layer, FusedMoE):
# Lazy import to avoid circular import.
from .awq_marlin import AWQMarlinConfig, AWQMoEMethod
from .moe_wna16 import MoeWNA16Config
from .utils.marlin_utils import check_moe_marlin_supports_layer
if not check_moe_marlin_supports_layer(layer, self.group_size):
logger.warning_once(
f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
"Falling back to Moe WNA16 kernels.")
config = {
"quant_method": "awq",
"bits": self.weight_bits,
"group_size": self.group_size,
"zero_point": self.zero_point,
"lm_head": False,
}
return MoeWNA16Config.from_config(config).get_quant_method(
layer, prefix)
marlin_compatible_config_dict = {
"quant_method": "awq",
"bits": self.weight_bits,
"group_size": self.group_size,
"zero_point": self.zero_point,
"lm_head": False,
"modules_to_not_convert": self.modules_to_not_convert,
}
awq_marlin_config = AWQMarlinConfig.from_config(
marlin_compatible_config_dict)
return AWQMoEMethod(awq_marlin_config, layer.moe_config)
return None
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: list[str]):
return any(module_name in prefix for module_name in modules_to_not_convert)
class AWQLinearMethod(LinearMethodBase):
"""Linear method for AWQ.
Args:
quant_config: The AWQ quantization config.
"""
def __init__(self, quant_config: AWQConfig):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
if input_size_per_partition % group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
weight_loader = extra_weight_attrs.get("weight_loader")
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
num_groups = input_size_per_partition // group_size
qzeros = PackedvLLMParameter(
data=torch.empty(
num_groups,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
scales = GroupQuantScaleParameter(data=torch.empty(
num_groups,
output_size_per_partition,
dtype=params_dtype,
),
input_dim=0,
output_dim=1,
weight_loader=weight_loader)
layer.register_parameter("qweight", qweight)
layer.register_parameter("qzeros", qzeros)
layer.register_parameter("scales", scales)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.qweight = torch.nn.Parameter(layer.qweight.data,
requires_grad=False)
layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
requires_grad=False)
layer.scales = torch.nn.Parameter(layer.scales.data,
requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = layer.qweight
scales = layer.scales
qzeros = layer.qzeros
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
reshaped_x = x.reshape(-1, x.shape[-1])
# num_tokens >= threshold
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
if FP16_MATMUL_HEURISTIC_CONDITION:
out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
out = torch.matmul(reshaped_x, out)
else:
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
pack_factor)
if bias is not None:
out.add_(bias)
return out.reshape(out_shape)

View File

@@ -0,0 +1,554 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional, Union
import torch
from torch.nn import Parameter
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.awq import (AWQConfig,
is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
check_marlin_supports_layer, check_moe_marlin_supports_layer,
marlin_make_empty_g_idx, marlin_make_workspace_new,
marlin_moe_permute_scales, marlin_permute_bias, marlin_permute_scales,
moe_awq_to_marlin_zero_points, verify_marlin_supported,
verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
class AWQMarlinConfig(QuantizationConfig):
"""Config class for AWQ Marlin"""
# num_bits -> type
TYPE_MAP = {
4: scalar_types.uint4,
8: scalar_types.uint8,
}
def __init__(self, weight_bits: int, group_size: int, zero_point: bool,
lm_head_quantized: bool,
modules_to_not_convert: Optional[list[str]],
full_config: dict[str, Any]) -> None:
super().__init__()
self.pack_factor = 32 // weight_bits # packed into int32
self.group_size = group_size
self.zero_point = zero_point
self.lm_head_quantized = lm_head_quantized
self.weight_bits = weight_bits
self.modules_to_not_convert = modules_to_not_convert or []
self.full_config = full_config
if self.weight_bits not in self.TYPE_MAP:
raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
f"Supported num_bits = {self.TYPE_MAP.keys()}")
self.quant_type = self.TYPE_MAP[self.weight_bits]
verify_marlin_supported(self.quant_type,
group_size=self.group_size,
has_zp=self.zero_point)
def __repr__(self) -> str:
return (f"AWQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point}, "
f"lm_head_quantized={self.lm_head_quantized}, "
f"modules_to_not_convert={self.modules_to_not_convert})")
@classmethod
def get_name(cls) -> QuantizationMethods:
return "awq_marlin"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: dict[str, Any]) -> "AWQMarlinConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None)
return cls(weight_bits, group_size, zero_point, lm_head_quantized,
modules_to_not_convert, config)
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "marlin"
or user_quant == "awq_marlin")
if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
if can_convert and user_quant == "awq":
logger.info("Detected that the model can run with awq_marlin"
", however you specified quantization=awq explicitly,"
" so forcing awq. Use quantization=awq_marlin for"
" faster inference")
return None
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
# Check if the layer is supported by AWQMarlin.
if not check_marlin_supports_layer(layer, self.group_size):
logger.warning_once(
"Layer '%s' is not supported by AWQMarlin. Falling back to unoptimized AWQ kernels.", # noqa: E501
prefix,
)
return AWQConfig.from_config(
self.full_config).get_quant_method(layer, prefix)
return AWQMarlinLinearMethod(self)
elif isinstance(layer, FusedMoE):
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config)
if is_layer_skipped_awq(
prefix, getattr(self, "modules_to_not_convert", [])):
return UnquantizedFusedMoEMethod(layer.moe_config)
if not check_moe_marlin_supports_layer(layer, self.group_size):
logger.warning_once(
f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
"Falling back to Moe WNA16 kernels.")
return MoeWNA16Config.from_config(
self.full_config).get_quant_method(layer, prefix)
return AWQMoEMethod(self, layer.moe_config)
return None
@classmethod
def is_awq_marlin_compatible(cls, quant_config: dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
zero_point = quant_config.get("zero_point")
if not current_platform.is_cuda():
return False
if quant_method != "awq":
return False
# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or zero_point is None):
return False
if num_bits not in cls.TYPE_MAP:
return False
return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
group_size=group_size,
has_zp=zero_point)
class AWQMarlinLinearMethod(LinearMethodBase):
"""Linear method for AWQ Marlin.
Args:
quant_config: The AWQ Marlin quantization config.
"""
def __init__(self, quant_config: AWQMarlinConfig) -> None:
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
del output_size
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=input_size,
group_size=group_size)
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
num_groups = input_size_per_partition // group_size
qzeros = PackedvLLMParameter(
data=torch.empty(
num_groups,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
scales = GroupQuantScaleParameter(data=torch.empty(
num_groups,
output_size_per_partition,
dtype=params_dtype,
),
input_dim=0,
output_dim=1,
weight_loader=weight_loader)
layer.register_parameter("qweight", qweight)
layer.register_parameter("qzeros", qzeros)
layer.register_parameter("scales", scales)
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.num_groups = num_groups
# TODO: Update this docs
# Checkpoints are serialized in AutoAWQ format, which is different from the
# marlin format. This function is called after the weights are loaded.
# Here, we handle the repacking
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.qweight.device
layer.qweight = torch.nn.Parameter(layer.qweight.data,
requires_grad=False)
layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
requires_grad=False)
layer.scales = torch.nn.Parameter(layer.scales.data,
requires_grad=False)
# Allocate marlin workspace
layer.workspace = marlin_make_workspace_new(device)
# Repack weights from AWQ format to marlin format.
marlin_qweight = ops.awq_marlin_repack(
layer.qweight,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits)
replace_parameter(layer, "qweight", marlin_qweight)
# Permute scales from AWQ format to marlin format.
marlin_scales = marlin_permute_scales(
layer.scales,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size)
replace_parameter(layer, "scales", marlin_scales)
# Permute zero-points from AWQ format to marlin format.
marlin_zp = awq_to_marlin_zero_points(
layer.qzeros,
size_k=layer.num_groups,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits)
replace_parameter(layer, "qzeros", marlin_zp)
# Not-used
layer.g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
if hasattr(layer, "bias") and layer.bias is not None:
layer.bias.data = marlin_permute_bias(layer.bias)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return apply_awq_marlin_linear(
input=x,
weight=layer.qweight,
weight_scale=layer.scales,
weight_zp=layer.qzeros,
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
quant_type=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
bias=bias)
class AWQMoEMethod(FusedMoEMethodBase):
def __init__(
self,
quant_config: AWQMarlinConfig,
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config
if self.quant_config.weight_bits != 4:
raise ValueError("AWQMoEMethod only supports 4bit now.")
self.quant_type = scalar_types.uint4
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
extra_weight_attrs.update({
"is_transposed":
True,
"quant_method":
FusedMoeWeightScaleSupported.GROUP.value,
})
w13_qweight = Parameter(
torch.empty(num_experts,
hidden_size,
2 * intermediate_size_per_partition //
self.quant_config.pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w13_qweight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs)
w2_qweight = Parameter(torch.empty(num_experts,
intermediate_size_per_partition,
hidden_size //
self.quant_config.pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w2_qweight", w2_qweight)
set_weight_attrs(w2_qweight, extra_weight_attrs)
num_groups_w13 = hidden_size // self.quant_config.group_size
num_groups_w2 = (intermediate_size_per_partition //
self.quant_config.group_size)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
w13_scales = Parameter(torch.empty(num_experts,
num_groups_w13,
intermediate_size_per_partition * 2,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_scales", w13_scales)
set_weight_attrs(w13_scales, extra_weight_attrs)
w2_scales = Parameter(torch.empty(num_experts,
num_groups_w2,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_scales", w2_scales)
set_weight_attrs(w2_scales, extra_weight_attrs)
# WEIGHT_ZERO_POINT
# Allocate 2 zero points for w1 and w3 respectively.
w13_qzeros = Parameter(
torch.empty(num_experts,
num_groups_w13,
2 * intermediate_size_per_partition //
self.quant_config.pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w13_qzeros", w13_qzeros)
set_weight_attrs(w13_qzeros, extra_weight_attrs)
w2_qzeros = Parameter(torch.empty(num_experts,
num_groups_w2,
hidden_size //
self.quant_config.pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)
device = layer.w13_qweight.device
layer.workspace = marlin_make_workspace_new(device, 4)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_qweight.shape[0]
device = layer.w13_qweight.device
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
marlin_w13_qweight = ops.awq_marlin_moe_repack(
layer.w13_qweight,
layer.w13_g_idx_sort_indices,
size_k=layer.w13_qweight.shape[1],
size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits,
)
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.awq_marlin_moe_repack(
layer.w2_qweight,
layer.w2_g_idx_sort_indices,
size_k=layer.w2_qweight.shape[1],
size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits,
)
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# Why does this take the intermediate size for size_k?
marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales,
size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size,
)
replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales,
size_k=layer.intermediate_size_per_partition,
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
)
replace_parameter(layer, "w2_scales", marlin_w2_scales)
marlin_w13_zp = moe_awq_to_marlin_zero_points(
layer.w13_qzeros,
size_k=layer.w13_qzeros.shape[1],
size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits)
replace_parameter(layer, "w13_qzeros", marlin_w13_zp)
marlin_w2_zp = moe_awq_to_marlin_zero_points(
layer.w2_qzeros,
size_k=layer.w2_qzeros.shape[1],
size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits)
replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
layer.w13_bias.data = marlin_permute_bias(layer.w13_bias)
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `AWQMoEMethod` yet.")
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,
getattr(layer, "w13_bias", None),
getattr(layer, "w2_bias", None),
layer.w13_scales,
layer.w2_scales,
router_logits,
topk_weights,
topk_ids,
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros,
workspace=layer.workspace)

View File

@@ -0,0 +1,320 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
@triton.jit
def awq_dequantize_kernel(
qweight_ptr, # quantized matrix
scales_ptr, # scales, per group
zeros_ptr, # zeros, per group
group_size, # Should always be one of the supported group sizes
result_ptr, # Output matrix
num_cols, # input num cols in qweight
num_rows, # input num rows in qweight
BLOCK_SIZE_X: tl.constexpr,
BLOCK_SIZE_Y: tl.constexpr):
# Set up the pids.
pid_x = tl.program_id(axis=0)
pid_y = tl.program_id(axis=1)
# Compute offsets and masks for qweight_ptr.
offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]
masks_y = offsets_y < num_rows
masks_x = offsets_x < num_cols
masks = masks_y[:, None] & masks_x[None, :]
# Compute offsets and masks for result output ptr.
result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(
0, BLOCK_SIZE_X * 8)
result_offsets = (8 * num_cols * result_offsets_y[:, None] +
result_offsets_x[None, :])
result_masks_y = result_offsets_y < num_rows
result_masks_x = result_offsets_x < num_cols * 8
result_masks = result_masks_y[:, None] & result_masks_x[None, :]
# Load the weights.
iweights = tl.load(qweight_ptr + offsets, masks, 0.0)
iweights = tl.interleave(iweights, iweights)
iweights = tl.interleave(iweights, iweights)
iweights = tl.interleave(iweights, iweights)
# Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
# that will map given indices to the correct order.
reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] +
tl.arange(0, 4)[:, None]).reshape(8)
# Use this to compute a set of shifts that can be used to unpack and
# reorder the values in iweights and zeros.
shifts = reverse_awq_order_tensor * 4
shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8))
shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
# Unpack and reorder: shift out the correct 4-bit value and mask.
iweights = (iweights >> shifts) & 0xF
# Compute zero offsets and masks.
zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :]
zero_masks_y = zero_offsets_y < num_rows // group_size
zero_masks_x = zero_offsets_x < num_cols
zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :]
# Load the zeros.
zeros = tl.load(zeros_ptr + zero_offsets, zero_masks, 0.0)
zeros = tl.interleave(zeros, zeros)
zeros = tl.interleave(zeros, zeros)
zeros = tl.interleave(zeros, zeros)
zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
# Unpack and reorder: shift out the correct 4-bit value and mask.
zeros = (zeros >> shifts) & 0xF
# Compute scale offsets and masks.
scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 +
tl.arange(0, BLOCK_SIZE_X * 8))
scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] +
scale_offsets_x[None, :])
scale_masks_y = scale_offsets_y < num_rows // group_size
scale_masks_x = scale_offsets_x < num_cols * 8
scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :]
# Load the scales.
scales = tl.load(scales_ptr + scale_offsets, scale_masks, 0.0)
scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
# Dequantize.
iweights = (iweights - zeros) * scales
iweights = iweights.to(result_ptr.type.element_ty)
# Finally, store.
tl.store(result_ptr + result_offsets, iweights, result_masks)
@triton.jit
def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
group_size, BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
SPLIT_K: tl.constexpr):
pid = tl.program_id(axis=0)
pid_z = tl.program_id(1)
# NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
# num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
pid_m = pid // num_pid_n
pid_n = pid % num_pid_n
accumulator_dtype = c_ptr.type.element_ty
# NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
# accumulator = tl.arange(0, BLOCK_SIZE_N)
# accumulator = tl.broadcast_to(accumulator[None, :],
# (BLOCK_SIZE_M, BLOCK_SIZE_N))
# accumulator = accumulator & 0x0
# accumulator = accumulator.to(accumulator_dtype)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),
dtype=accumulator_dtype)
# Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
# that will map given indices to the correct order.
reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] +
tl.arange(0, 4)[:, None]).reshape(8)
# Create the necessary shifts to use to unpack.
shifts = reverse_awq_order_tensor * 4
shifts = tl.broadcast_to(shifts[None, :],
(BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8))
shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N))
# Offsets and masks.
offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
masks_am = offsets_am < M
offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
masks_bn = offsets_bn < N // 8
offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
masks_zn = offsets_zn < N // 8
offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
masks_sn = offsets_sn < N
offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offsets_a = K * offsets_am[:, None] + offsets_k[None, :]
offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :]
a_ptrs = a_ptr + offsets_a
b_ptrs = b_ptr + offsets_b
# NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv
# block_offset = BLOCK_SIZE_K * SPLIT_K
# for k in range(0, (K + block_offset - 1) // (block_offset)):
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
masks_k = offsets_k < K
masks_a = masks_am[:, None] & masks_k[None, :]
a = tl.load(a_ptrs, mask=masks_a, other=0.0)
masks_b = masks_k[:, None] & masks_bn[None, :]
b = tl.load(b_ptrs, mask=masks_b, other=0.0)
b = tl.interleave(b, b)
b = tl.interleave(b, b)
b = tl.interleave(b, b)
# Dequantize b.
offsets_szk = (
(BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size +
tl.arange(0, 1))
offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]
masks_zk = offsets_szk < K // group_size
masks_z = masks_zk[:, None] & masks_zn[None, :]
zeros_ptrs = zeros_ptr + offsets_z
zeros = tl.load(zeros_ptrs, mask=masks_z, other=0.0)
zeros = tl.interleave(zeros, zeros)
zeros = tl.interleave(zeros, zeros)
zeros = tl.interleave(zeros, zeros)
zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N))
offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]
masks_sk = offsets_szk < K // group_size
masks_s = masks_sk[:, None] & masks_sn[None, :]
scales_ptrs = scales_ptr + offsets_s
scales = tl.load(scales_ptrs, mask=masks_s, other=0.0)
scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))
b = (b >> shifts) & 0xF
zeros = (zeros >> shifts) & 0xF
b = (b - zeros) * scales
b = b.to(c_ptr.type.element_ty)
# Accumulate results.
accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
offsets_k += BLOCK_SIZE_K * SPLIT_K
a_ptrs += BLOCK_SIZE_K * SPLIT_K
b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8)
c = accumulator.to(c_ptr.type.element_ty)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + pid_z * N * M + N * offs_cm[:, None] + offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
# qweights - [K , M // 8], int32
# scales - [K // G, M ], float16
# zeros - [K // G, M // 8], int32
def awq_dequantize_triton(qweight: torch.Tensor,
scales: torch.Tensor,
zeros: torch.Tensor,
block_size_x: int = 32,
block_size_y: int = 32) -> torch.Tensor:
K = qweight.shape[0]
M = scales.shape[1]
group_size = qweight.shape[0] // scales.shape[0]
assert K > 0 and M > 0
assert scales.shape[0] == K // group_size and scales.shape[1] == M
assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8
assert group_size <= K
assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K
# Result tensor:
# number of rows = same as input tensor
# number of cols = 8 x input tensor num cols
result = torch.empty(qweight.shape[0],
qweight.shape[1] * 8,
device=qweight.device,
dtype=scales.dtype)
Y = qweight.shape[0] # num rows
X = qweight.shape[1] # num cols
grid = lambda META: (
triton.cdiv(X, META['BLOCK_SIZE_X']),
triton.cdiv(Y, META['BLOCK_SIZE_Y']),
)
awq_dequantize_kernel[grid](qweight,
scales,
zeros,
group_size,
result,
X,
Y,
BLOCK_SIZE_X=block_size_x,
BLOCK_SIZE_Y=block_size_y)
return result
# input - [M, K]
# qweight - [K, N // 8]
# qzeros - [K // G, N // 8]
# scales - [K // G, N]
# split_k_iters - parallelism along K-dimension, int, power of 2.
def awq_gemm_triton(input: torch.Tensor,
qweight: torch.Tensor,
scales: torch.Tensor,
qzeros: torch.Tensor,
split_k_iters: int,
block_size_m: int = 32,
block_size_n: int = 32,
block_size_k: int = 32) -> torch.Tensor:
M, K = input.shape
N = qweight.shape[1] * 8
group_size = qweight.shape[0] // qzeros.shape[0]
assert N > 0 and K > 0 and M > 0
assert qweight.shape[0] == K and qweight.shape[1] == N // 8
assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8
assert scales.shape[0] == K // group_size and scales.shape[1] == N
assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0
assert split_k_iters <= 32
assert group_size <= K
assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K
grid = lambda META: (
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
N, META['BLOCK_SIZE_N']),
split_k_iters,
)
result = torch.zeros((split_k_iters, M, N),
dtype=scales.dtype,
device=input.device)
# A = input, B = qweight, C = result
# A = M x K, B = K x N, C = M x N
awq_gemm_kernel[grid](input,
qweight,
result,
qzeros,
scales,
M,
N,
K,
group_size,
BLOCK_SIZE_M=block_size_m,
BLOCK_SIZE_N=block_size_n,
BLOCK_SIZE_K=block_size_k,
SPLIT_K=split_k_iters)
result = result.sum(0)
return result

View File

@@ -0,0 +1,170 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
import torch
from torch import nn
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.models.utils import WeightsMapper
else:
QuantizationMethods = str
class QuantizeMethodBase(ABC):
"""Base class for different quantized methods."""
@abstractmethod
def create_weights(self, layer: torch.nn.Module, *weight_args,
**extra_weight_attrs):
"""Create weights for a layer.
The weights will be set as attributes of the layer."""
raise NotImplementedError
@abstractmethod
def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
# Not required functions
def embedding(self, layer: torch.nn.Module, *args,
**kwargs) -> torch.Tensor:
"""Gather embeddings in the layer based on indices in the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
def process_weights_after_loading(self, layer: nn.Module) -> None:
"""Process the weight after loading.
This can be used for example, to transpose weights for computation.
"""
return
def method_has_implemented_embedding(
method_class: type[QuantizeMethodBase]) -> bool:
"""
Not all quant methods have embedding implemented, so we need to check that
it exists for our given method. We check this by making sure the function
has been changed from the base implementation.
"""
base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding",
None)
class_embedding = inspect.getattr_static(method_class, "embedding", None)
return (class_embedding is not None
and class_embedding is not base_embedding)
class QuantizationConfig(ABC):
"""Base class for quantization configs."""
def __init__(self):
super().__init__()
# mapping is updated by models as they initialize
self.packed_modules_mapping: dict[str, list[str]] = dict()
@abstractmethod
def get_name(self) -> QuantizationMethods:
"""Name of the quantization method."""
raise NotImplementedError
@abstractmethod
def get_supported_act_dtypes(self) -> list[torch.dtype]:
"""List of supported activation dtypes."""
raise NotImplementedError
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
"""Minimum GPU capability to support the quantization method.
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
This requirement is due to the custom CUDA kernels used by the
quantization method.
"""
raise NotImplementedError
@staticmethod
@abstractmethod
def get_config_filenames() -> list[str]:
"""List of filenames to search for in the model directory."""
raise NotImplementedError
@classmethod
@abstractmethod
def from_config(cls, config: dict[str, Any]) -> "QuantizationConfig":
"""Create a config class from the model's quantization config."""
raise NotImplementedError
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
"""
Detects if this quantization method can support a given checkpoint
format by overriding the user specified quantization method --
this method should only be overwritten by subclasses in exceptional
circumstances
"""
return None
@staticmethod
def get_from_keys(config: dict[str, Any], keys: list[str]) -> Any:
"""Get a value from the model's quantization config."""
for key in keys:
if key in config:
return config[key]
raise ValueError(f"Cannot find any of {keys} in the model's "
"quantization config.")
@staticmethod
def get_from_keys_or(config: dict[str, Any], keys: list[str],
default: Any) -> Any:
"""Get an optional value from the model's quantization config."""
try:
return QuantizationConfig.get_from_keys(config, keys)
except ValueError:
return default
@abstractmethod
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional[QuantizeMethodBase]:
"""Get the quantize method to use for the quantized layer.
Args:
layer: The layer for the quant method.
prefix: The full name of the layer in the state dict
Returns:
The quantize method. None if the given layer doesn't support quant
method.
"""
raise NotImplementedError
def get_cache_scale(self, name: str) -> Optional[str]:
return None
def apply_vllm_mapper( # noqa: B027
self, hf_to_vllm_mapper: "WeightsMapper"):
"""
Interface for models to update module names referenced in
quantization configs in order to reflect the vllm model structure
:param hf_to_vllm_mapper: maps from hf model structure (the assumed
structure of the qconfig) to vllm model structure
"""
# TODO (@kylesayrs): add implementations for all subclasses
pass
def maybe_update_config(self, model_name: str): # noqa: B027
"""
Interface to update values after config initialization.
"""
pass

View File

@@ -0,0 +1,464 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
from packaging import version
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_NUM_BITS,
BITBLAS_SUPPORTED_SYM, MINIMUM_BITBLAS_VERSION)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
class BitBLASConfig(QuantizationConfig):
"""Config class for BitBLAS.
Reference: https://github.com/Microsoft/BitBLAS
"""
TORCH_DTYPE = torch.float16
STORAGE_DTYPE = "int8" # assume int8 storage
TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE)
# "original" or "rescale" or "quantized",
# gptq_with_bitblas prefer "quantized implementation"
ZEROS_MODE = "quantized"
def __init__(
self,
weight_bits: int,
group_size: Optional[int],
desc_act: Optional[bool],
is_sym: Optional[bool],
quant_method: Optional[str],
lm_head_quantized: bool,
) -> None:
try:
import bitblas
if version.parse(bitblas.__version__) < version.parse(
MINIMUM_BITBLAS_VERSION):
raise ImportError(
"bitblas version is wrong. Please "
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
except ImportError as e:
bitblas_import_exception = e
raise ValueError(
"Trying to use the bitblas backend, but could not import"
f"with the following error: {bitblas_import_exception}. "
"Please install bitblas through the following command: "
f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
) from bitblas_import_exception
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act = False
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
self.quant_method = quant_method
self.lm_head_quantized = lm_head_quantized
# Verify
if self.weight_bits not in BITBLAS_SUPPORTED_NUM_BITS:
raise ValueError(
f"BitBLAS does not support weight_bits = {self.weight_bits}. "
f"Only weight_bits = {BITBLAS_SUPPORTED_NUM_BITS} "
"are supported.")
if self.is_sym not in BITBLAS_SUPPORTED_SYM:
raise ValueError(
f"BitBLAS does not support is_sym = {self.is_sym}. "
f"Only sym = {BITBLAS_SUPPORTED_SYM} are supported.")
storage_dtype = self.STORAGE_DTYPE
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
self.storage_dtype = storage_dtype
self.storage_torch_dtype = self.TORCH_STORAGE_DTYPE
# 4 Bits packed into 32 bit datatype.
self.pack_factor = storage_nbit // weight_bits
self.nbits = weight_bits
# Zeros type for the quantized weights.
self.zeros_mode = self.ZEROS_MODE
def __repr__(self) -> str:
return (f"BitBLASConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}, "
f"is_sym={self.is_sym}, "
f"quant_method={self.quant_method})")
@classmethod
def get_name(cls) -> QuantizationMethods:
return "bitblas"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
# Need to figure it out
def get_min_capability(cls) -> int:
return 70
@classmethod
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@staticmethod
def get_from_keys(config: dict[str, Any],
keys: list[str],
default: Any = None) -> Any:
"""Get a value from the model's quantization config."""
for key in keys:
if key in config:
return config[key]
return default
@classmethod
def from_config(cls, config: dict[str, Any]) -> "BitBLASConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"], -1)
desc_act = cls.get_from_keys(config, ["desc_act"], False)
is_sym = cls.get_from_keys(config, ["sym"], False)
quant_method = cls.get_from_keys(config, ["quant_method"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, is_sym, quant_method,
lm_head_quantized)
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
# compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_bitblas_format: bool
is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas"
or hf_quant_cfg.get("is_bitblas_format", False))
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
or user_quant == "bitblas")
if is_bitblas_format and is_valid_user_quant:
msg = ("The model is serialized in {} format. Using {} kernel.".
format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
return None
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["BitBLASLinearMethod"]:
if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
and self.lm_head_quantized):
return BitBLASLinearMethod(self)
return None
class BitBLASLinearMethod(LinearMethodBase):
"""Linear method for BitBLAS.
Args:
quant_config: The BitBLAS quantization config.
"""
# USE BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS
# Instead of BITBLAS_OPTIMIZE_FEATURES
# If you want to high contiguous batching
# performance
OPT_FEATURES = BITBLAS_OPTIMIZE_FEATURES
ENABLE_TUNING = True
BITBLAS_DTYPES = {
torch.float32: "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.half: "float16",
torch.int8: "int8",
}
def __init__(self, quant_config: BitBLASConfig):
self.quant_config = quant_config
def create_weights_gptq(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
"""Creates quantized weights for use in linear operations.
The function initializes and returns a dictionary containing quantized
weights, scales, and zeros
for performing quantized matrix multiplication operations.
Args:
input_size_per_partition: The size of the input partition.
output_partition_sizes: List of output partition sizes.
input_size: The total size of the input (unused).
output_size: The total size of the output (unused).
params_dtype:
The data type of the parameters (expected to be torch.float16).
Returns:
A dictionary containing the quantized weights ('qweight'),
scales ('scales'), and zeros ('zeros').
Raises:
ValueError: If `params_dtype` is not `torch.float16` or if the input
size per partition is not divisible by the group size
in `quant_config`.
"""
del input_size, output_size # Unused arguments.
weight_loader = extra_weight_attrs["weight_loader"]
if params_dtype not in self.quant_config.get_supported_act_dtypes():
raise ValueError("Parameter data type must be torch.float16, "
f"but got {params_dtype}")
group_size = self.quant_config.group_size
if group_size is None:
group_size = -1
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
if (group_size != -1 and input_size_per_partition % group_size != 0):
raise ValueError(
f"Input size per partition ({input_size_per_partition}) must "
f"be divisible by group size ({group_size}).")
# Initialize or retrieve the BitBLAS matrix multiplication operator.
self._configure_bitblas_matmul(
input_size_per_partition,
output_size_per_partition,
params_dtype=params_dtype,
enable_tuning=self.ENABLE_TUNING,
bias=False,
layout="nt",
bits=self.quant_config.weight_bits,
)
# Initialize quantized weights with dimensions
# Quantized 4Bit weights packed.
qweight = PackedvLLMParameter(
data=torch.empty(
self.bitblas_matmul.retrieve_weight_shape(),
device="cuda",
dtype=self.quant_config.storage_torch_dtype,
requires_grad=False,
),
input_dim=1,
output_dim=0,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
bitblas_tile_size=(self.bitblas_matmul.retrieve_weight_shape()[-2]
if self.bitblas_matmul.propagate_b else None),
weight_loader=weight_loader,
)
# Compute the number of input groups for channel-wise quantization.
input_groups = (1 if group_size == -1 else input_size_per_partition //
group_size)
# Initialize scales and zeros for the quantized weights.
weight_scale_args = {
"data":
torch.empty(
output_size_per_partition,
input_groups,
device="cuda",
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if input_groups == 1:
scales = ChannelQuantScaleParameter(output_dim=0,
**weight_scale_args)
else:
scales = GroupQuantScaleParameter(output_dim=0,
input_dim=1,
**weight_scale_args)
if self.quant_config.zeros_mode == "quantized":
zeros = PackedvLLMParameter(
data=torch.empty(
input_groups,
output_size_per_partition // self.quant_config.pack_factor,
device="cuda",
dtype=self.quant_config.storage_torch_dtype,
requires_grad=False,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader,
)
else:
zeros = BasevLLMParameter(
torch.empty(output_size_per_partition,
input_groups,
device="cuda",
dtype=params_dtype),
weight_loader=weight_loader,
)
# Set attributes to indicate how scales and zeros are applied.
set_weight_attrs(zeros, {
"input_dim": None if input_groups == 1 else 1,
"output_dim": 0,
})
layer.register_parameter("qweight", qweight)
layer.register_parameter("scales", scales)
layer.register_parameter("zeros", zeros)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
if self.quant_config.quant_method == "gptq":
return self.create_weights_gptq(layer, input_size_per_partition,
output_partition_sizes, input_size,
output_size, params_dtype,
**extra_weight_attrs)
else:
raise ValueError(
f"Unsupported quant_method {self.quant_config.quant_method}")
def _configure_bitblas_matmul(
self,
infeatures,
outfeatures,
params_dtype,
enable_tuning,
bias,
layout,
bits,
out_dtype="float16",
):
from bitblas import MatmulConfig
bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]
with_scaling = False
with_zeros = False
group_size = self.quant_config.group_size
zeros_mode = self.quant_config.zeros_mode
if self.quant_config.quant_method == "gptq":
with_scaling = True
with_zeros = True
W_dtype = f"uint{bits}"
if self.quant_config.is_sym:
with_zeros = False
W_dtype = f"int{bits}"
else:
raise ValueError(
f"Unsupported quant_method {self.quant_config.quant_method}")
matmul_config = MatmulConfig(
N=outfeatures,
K=infeatures,
A_dtype=bitblas_dtype,
W_dtype=W_dtype,
out_dtype=out_dtype,
accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype,
storage_dtype=self.quant_config.STORAGE_DTYPE,
with_scaling=with_scaling,
with_zeros=with_zeros,
group_size=group_size,
with_bias=bias,
layout=layout,
zeros_mode=zeros_mode,
)
self.bitblas_matmul = self._get_or_create_bitblas_operator(
matmul_config, enable_tuning)
def _get_or_create_bitblas_operator(self, config, enable_tuning):
from bitblas import Matmul, auto_detect_nvidia_target
from bitblas.cache import get_database_path, global_operator_cache
BITBLAS_DATABASE_PATH = get_database_path()
BITBLAS_TARGET = auto_detect_nvidia_target()
if global_operator_cache.size() == 0:
global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
BITBLAS_TARGET)
bitblas_matmul = global_operator_cache.get(config)
if bitblas_matmul is None:
bitblas_matmul = Matmul(config,
target=BITBLAS_TARGET,
enable_tuning=False)
if enable_tuning:
TUNING_MESSAGE = (f"BitBLAS Operator {config} is tuning ...")
logger.info(TUNING_MESSAGE)
bitblas_matmul.hardware_aware_finetune(topk=20)
global_operator_cache.add(config, bitblas_matmul)
global_operator_cache.save_into_database(
BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
TUNED_MESSAGE = (
f"BitBLAS Operator {config} tuned and saved to database.")
logger.info(TUNED_MESSAGE)
else:
_message = f"BitBLAS Operator {config} created."
logger.info(_message)
else:
_message = (
f"BitBLAS Operator {config} found in global_operator_cache.")
logger.info(_message)
return bitblas_matmul
def apply_gptq(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qweight = layer.qweight
scales = layer.scales
qzeros = layer.zeros
x_2d = x.view(-1, x.shape[-1])
if self.quant_config.is_sym:
output_2d = self.bitblas_matmul(x_2d, qweight, scales)
else:
output_2d = self.bitblas_matmul(x_2d, qweight, scales, qzeros)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
if bias is not None:
output.add_(bias) # In-place add
return output
def apply(
self,
*args: Any,
**kwargs: Any,
) -> torch.Tensor:
if self.quant_config.quant_method == "gptq":
return self.apply_gptq(*args, **kwargs)
else:
raise ValueError(
f"Unsupported quant_method {self.quant_config.quant_method}")

View File

@@ -0,0 +1,627 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional, Union
import torch
from packaging import version
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
class BitsAndBytesConfig(QuantizationConfig):
"""Config class for BitsAndBytes Quantization.
Reference: https://arxiv.org/abs/2305.14314
"""
def __init__(
self,
load_in_8bit: bool = False,
load_in_4bit: bool = True,
bnb_4bit_compute_dtype: str = "float32",
bnb_4bit_quant_storage: str = "uint8",
bnb_4bit_quant_type: str = "fp4",
bnb_4bit_use_double_quant: bool = False,
llm_int8_enable_fp32_cpu_offload: bool = False,
llm_int8_has_fp16_weight: bool = False,
llm_int8_skip_modules: Optional[list[str]] = None,
llm_int8_threshold: float = 6.0,
) -> None:
super().__init__()
self.load_in_8bit = load_in_8bit
self.load_in_4bit = load_in_4bit
self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
self.bnb_4bit_quant_storage = bnb_4bit_quant_storage
self.bnb_4bit_quant_type = bnb_4bit_quant_type
self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
self.llm_int8_skip_modules = llm_int8_skip_modules or []
self.llm_int8_threshold = llm_int8_threshold
if self.bnb_4bit_quant_storage not in ["uint8"]:
raise ValueError("Unsupported bnb_4bit_quant_storage: "
f"{self.bnb_4bit_quant_storage}")
def __repr__(self) -> str:
return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, "
f"load_in_4bit={self.load_in_4bit}, "
f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
f"bnb_4bit_quant_storage={self.bnb_4bit_quant_storage}, "
f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
@classmethod
def get_name(self) -> QuantizationMethods:
return "bitsandbytes"
@classmethod
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.float32, torch.float16, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 70
@staticmethod
def get_config_filenames() -> list[str]:
return []
@classmethod
def from_config(cls, config: dict[str, Any]) -> "BitsAndBytesConfig":
def get_safe_value(config, keys, default_value=None):
try:
value = cls.get_from_keys(config, keys)
return value if value is not None else default_value
except ValueError:
return default_value
load_in_8bit = get_safe_value(config, ["load_in_8bit"],
default_value=False)
load_in_4bit = get_safe_value(config, ["load_in_4bit"],
default_value=True)
bnb_4bit_compute_dtype = get_safe_value(config,
["bnb_4bit_compute_dtype"],
default_value="float32")
bnb_4bit_quant_storage = get_safe_value(config,
["bnb_4bit_quant_storage"],
default_value="uint8")
bnb_4bit_quant_type = get_safe_value(config, ["bnb_4bit_quant_type"],
default_value="fp4")
bnb_4bit_use_double_quant = get_safe_value(
config, ["bnb_4bit_use_double_quant"], default_value=False)
llm_int8_enable_fp32_cpu_offload = get_safe_value(
config, ["llm_int8_enable_fp32_cpu_offload"], default_value=False)
llm_int8_has_fp16_weight = get_safe_value(config,
["llm_int8_has_fp16_weight"],
default_value=False)
llm_int8_skip_modules = get_safe_value(config,
["llm_int8_skip_modules"],
default_value=[])
llm_int8_threshold = get_safe_value(config, ["llm_int8_threshold"],
default_value=6.0)
return cls(
load_in_8bit=load_in_8bit,
load_in_4bit=load_in_4bit,
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
bnb_4bit_quant_storage=bnb_4bit_quant_storage,
bnb_4bit_quant_type=bnb_4bit_quant_type,
bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
llm_int8_enable_fp32_cpu_offload=llm_int8_enable_fp32_cpu_offload,
llm_int8_has_fp16_weight=llm_int8_has_fp16_weight,
llm_int8_skip_modules=llm_int8_skip_modules,
llm_int8_threshold=llm_int8_threshold)
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[Union["LinearMethodBase", "BitsAndBytesMoEMethod"]]:
if isinstance(layer, LinearBase):
if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
return UnquantizedLinearMethod()
return BitsAndBytesLinearMethod(self)
elif isinstance(layer, FusedMoE):
return BitsAndBytesMoEMethod(self, layer.moe_config)
return None
def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
# Split the prefix into its dot-separated components
components = prefix.split('.')
# Check if any of the skip modules exactly matches any component
substr_check = any(module_name in components
for module_name in llm_int8_skip_modules)
# Allow certain layers to not be quantized
set_components = set(".".join(components[:i + 1])
for i in range(len(components)))
set_llm_int8_skip_modules = set(llm_int8_skip_modules)
prefix_check = len(set_llm_int8_skip_modules & set_components) != 0
return substr_check or prefix_check
def calculate_quant_ratio(dtype):
if dtype.is_floating_point:
return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits
else:
return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits
class BitsAndBytesLinearMethod(LinearMethodBase):
"""Linear method for BitsAndBytes.
Args:
quant_config: The BitsAndBytes quantization config.
"""
def __init__(self, quant_config: BitsAndBytesConfig):
try:
import bitsandbytes
if version.parse(
bitsandbytes.__version__) < version.parse("0.46.1"):
raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.46.1.")
except ImportError as err:
raise ImportError("Please install bitsandbytes>=0.46.1 via "
"`pip install bitsandbytes>=0.46.1` to use "
"bitsandbytes quantizer.") from err
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
from bitsandbytes.nn import Int8Params
def create_qweight_for_8bit():
qweight = Int8Params(
data=torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=torch.int8),
has_fp16_weights=self.quant_config.llm_int8_has_fp16_weight,
requires_grad=False)
set_weight_attrs(
qweight, {
"input_dim": 0,
"output_dim": 0,
"pack_factor": 1,
"use_bitsandbytes_8bit": True,
"generation": 0
})
return qweight
def create_qweight_for_4bit():
quant_ratio = calculate_quant_ratio(params_dtype)
total_size = input_size_per_partition * sum(output_partition_sizes)
if total_size % quant_ratio != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape.")
qweight = torch.nn.Parameter(torch.empty(total_size // quant_ratio,
1,
dtype=torch.uint8),
requires_grad=False)
set_weight_attrs(
qweight, {
"input_dim": 0,
"output_dim": 0,
"pack_factor": quant_ratio,
"use_bitsandbytes_4bit": True
})
return qweight
if self.quant_config.load_in_8bit:
qweight = create_qweight_for_8bit()
else:
qweight = create_qweight_for_4bit()
# Enable parameters to have the same name as in the BNB
# checkpoint format.
layer.register_parameter("weight", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.quant_config.load_in_8bit:
return self._apply_8bit_weight(layer, x, bias)
else:
return self._apply_4bit_weight(layer, x, bias)
def _apply_8bit_weight(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# only load the bitsandbytes module when needed
from bitsandbytes import MatmulLtState, matmul
original_type = x.dtype
original_shape = x.shape
reshape_after_matmul = False
if x.ndim > 2:
x = x.reshape(-1, x.size(-1))
reshape_after_matmul = True
bf_x = x.to(torch.bfloat16)
qweight = layer.weight
offsets = qweight.bnb_shard_offsets
quant_states = qweight.bnb_quant_state
matmul_states = qweight.matmul_state
generation = qweight.generation
out_dim_0 = x.shape[0]
out_dim_1 = sum(
[quant_state[1].shape[0] for quant_state in quant_states.items()])
out = torch.empty(out_dim_0,
out_dim_1,
dtype=torch.float16,
device=x.device)
current_index = 0
for i in range(len(quant_states)):
output_size = quant_states[i].shape[0]
# in profile_run or the first generation of inference,
# create new matmul_states
if generation == 0 or generation == 1:
matmul_states[i] = MatmulLtState()
matmul_states[i].CB = qweight[offsets[i]:offsets[i + 1]]
matmul_states[i].SCB = quant_states[i].to(x.device)
matmul_states[i].threshold = (
self.quant_config.llm_int8_threshold)
matmul_states[i].has_fp16_weights = (
self.quant_config.llm_int8_has_fp16_weight)
matmul_states[i].is_training = False
if matmul_states[i].threshold > 0.0 and not matmul_states[
i].has_fp16_weights:
matmul_states[i].use_pool = True
new_x = bf_x.unsqueeze(0)
out[:, current_index:current_index + output_size] = matmul(
new_x,
qweight[offsets[i]:offsets[i + 1]],
state=matmul_states[i])
current_index += output_size
# only update the matmul_states if it is not profile_run
if (generation > 0
and not self.quant_config.llm_int8_has_fp16_weight
and matmul_states[i].CB is not None
and matmul_states[i].CxB is not None):
del matmul_states[i].CB
qweight[offsets[i]:offsets[i + 1]] = matmul_states[i].CxB
out = out.to(original_type)
if reshape_after_matmul:
out = out.view(*original_shape[:-1], out.size(-1))
if bias is not None:
out += bias
qweight.generation += 1
return out
def _apply_4bit_weight(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
original_type = x.dtype
original_shape = x.shape
reshape_after_matmul = False
if x.ndim > 2:
x = x.reshape(-1, x.size(-1))
reshape_after_matmul = True
bf_x = x.to(torch.bfloat16)
qweight = layer.weight
quant_states = qweight.bnb_quant_state
offsets = qweight.bnb_shard_offsets
out_dim_0 = x.shape[0]
out_dim_1 = sum(
[quant_state[1].shape[0] for quant_state in quant_states.items()])
out = torch.empty(out_dim_0,
out_dim_1,
dtype=torch.bfloat16,
device=x.device)
apply_bnb_4bit(bf_x, qweight, offsets, out)
out = out.to(original_type)
if reshape_after_matmul:
out = out.view(*original_shape[:-1], out.size(-1))
if bias is not None:
out += bias
return out
def _apply_bnb_4bit(
x: torch.Tensor,
weight: torch.Tensor,
offsets: torch.Tensor,
out: torch.Tensor,
) -> None:
# only load the bitsandbytes module when needed
from bitsandbytes import matmul_4bit
quant_states = weight.bnb_quant_state
current_index = 0
for i in range(len(quant_states)):
output_size = quant_states[i].shape[0]
# It is more efficient to use out kwarg like
# matmul_4bit(..., out = ...). Infeasible now due to the bug
# https://github.com/TimDettmers/bitsandbytes/issues/1235.
# Need to change after the bug is fixed.
out[:, current_index:current_index + output_size] = matmul_4bit(
x, weight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
current_index += output_size
def _apply_bnb_4bit_fake(
x: torch.Tensor,
weight: torch.Tensor,
offsets: torch.Tensor,
out: torch.Tensor,
) -> None:
return
try:
direct_register_custom_op(op_name="apply_bnb_4bit",
op_func=_apply_bnb_4bit,
mutates_args=["out"],
fake_impl=_apply_bnb_4bit_fake,
dispatch_key=current_platform.dispatch_key)
apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit
except AttributeError as error:
raise error
class BitsAndBytesMoEMethod(FusedMoEMethodBase):
"""MoE method for BitsAndBytes.
Args:
quant_config: The BitsAndBytes quantization config.
"""
def __init__(
self,
quant_config: BitsAndBytesConfig,
moe: FusedMoEConfig,
):
super().__init__(moe)
try:
import bitsandbytes
if version.parse(
bitsandbytes.__version__) < version.parse("0.46.1"):
raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.46.1.")
except ImportError as err:
raise ImportError("Please install bitsandbytes>=0.46.1 via "
"`pip install bitsandbytes>=0.46.1` to use "
"bitsandbytes quantizer.") from err
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
if self.quant_config.load_in_8bit:
call_fun = self._create_weights_8bit
else:
call_fun = self._create_weights_4bit
call_fun(
layer,
num_experts,
hidden_size,
intermediate_size_per_partition,
params_dtype,
**extra_weight_attrs,
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
from vllm.model_executor.layers.fused_moe import fused_experts
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `BitsAndBytesMoEMethod` yet.")
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
if self.quant_config.load_in_8bit:
w13, w2 = self._apply_8bit_dequant(layer)
else:
w13, w2 = self._apply_4bit_dequnt(layer)
return fused_experts(
hidden_states=x,
w1=w13,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
quant_config=self.moe_quant_config,
)
def _create_weights_4bit(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
quant_ratio = calculate_quant_ratio(params_dtype)
# Fused gate_up_proj (column parallel)
w13_total_size = (hidden_size * 2 *
intermediate_size_per_partition) // quant_ratio
w13_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
w13_total_size,
1,
dtype=torch.uint8,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs)
set_weight_attrs(
w13_qweight,
{
"num_experts":
num_experts,
"input_dim":
hidden_size,
"output_dim":
2 * intermediate_size_per_partition,
"experts_shape": (
num_experts,
intermediate_size_per_partition * 2,
hidden_size,
),
"pack_factor":
quant_ratio,
"use_bitsandbytes_4bit":
True,
},
)
# down_proj (row parallel)
w2_total_size = (hidden_size *
intermediate_size_per_partition) // quant_ratio
w2_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
w2_total_size,
1,
dtype=torch.uint8,
),
requires_grad=False,
)
set_weight_attrs(
w2_qweight,
{
"num_experts":
num_experts,
"input_dim":
intermediate_size_per_partition,
"output_dim":
hidden_size,
"experts_shape": (
num_experts,
hidden_size,
intermediate_size_per_partition,
),
"pack_factor":
quant_ratio,
"use_bitsandbytes_4bit":
True,
},
)
layer.register_parameter("w2_weight", w2_qweight)
set_weight_attrs(w2_qweight, extra_weight_attrs)
def _create_weights_8bit(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
raise NotImplementedError
def _apply_4bit_dequnt(
self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
from bitsandbytes.functional import dequantize_4bit
w13 = dequantize_4bit(
layer.w13_weight.reshape(-1, 1),
layer.w13_weight.bnb_quant_state,
)
w2 = dequantize_4bit(
layer.w2_weight.reshape(-1, 1),
layer.w2_weight.bnb_quant_state,
)
w13 = w13.reshape(layer.w13_weight.experts_shape)
w2 = w2.reshape(layer.w2_weight.experts_shape)
return w13, w2
def _apply_8bit_dequant(
self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

View File

@@ -0,0 +1,797 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import suppress
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import torch
from compressed_tensors.config import (CompressionFormat,
SparsityCompressionConfig,
SparsityStructure)
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
from compressed_tensors.transform import TransformConfig
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
CompressedTensorsMoEMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
CompressedTensorsScheme, CompressedTensorsW4A4Fp4,
CompressedTensorsW4A8Fp8, CompressedTensorsW4A8Int,
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
CompressedTensorsLinearTransformMethod, get_linear_transform_schemes)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target, is_activation_quantization_format,
should_ignore_layer)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported)
from vllm.platforms import current_platform
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
logger = init_logger(__name__)
__all__ = ["CompressedTensorsLinearMethod"]
SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]]
class CompressedTensorsConfig(QuantizationConfig):
def __init__(
self,
target_scheme_map: dict[str, Any],
ignore: list[str],
quant_format: str,
sparsity_scheme_map: dict[str, SparsityCompressionConfig],
sparsity_ignore_list: list[str],
kv_cache_scheme: Optional[dict[str, Any]] = None,
config: Optional[dict[str, Any]] = None,
transform_config: Optional[dict[str, Any]] = None,
):
super().__init__()
self.ignore = ignore
self.quant_format = quant_format
# Map from [target -> scheme]
self.target_scheme_map = target_scheme_map
self.kv_cache_scheme = kv_cache_scheme
self.sparsity_scheme_map = sparsity_scheme_map
self.sparsity_ignore_list = sparsity_ignore_list
self.config = config
if transform_config:
self.transform_config = TransformConfig.model_validate(
transform_config)
else:
self.transform_config = None
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self)
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.float32, torch.float16, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 70
def get_name(self) -> QuantizationMethods:
return "compressed-tensors"
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
self.target_scheme_map = hf_to_vllm_mapper.apply_dict(
self.target_scheme_map)
self.ignore = hf_to_vllm_mapper.apply_list(self.ignore)
self.sparsity_scheme_map = hf_to_vllm_mapper.apply_dict(
self.sparsity_scheme_map)
self.sparsity_ignore_list = hf_to_vllm_mapper.apply_list(
self.sparsity_ignore_list)
if self.kv_cache_scheme is not None:
self.kv_cache_scheme = hf_to_vllm_mapper.apply_dict(
self.kv_cache_scheme)
def get_quant_method(
self,
layer: torch.nn.Module,
prefix: str,
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
# collect schemes
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
input_tfms, output_tfms = get_linear_transform_schemes(
layer, prefix, self.transform_config,
self.packed_modules_mapping)
# choose quantization method
quant_method: LinearMethodBase = UnquantizedLinearMethod()
if quant_scheme is not None:
layer.scheme = quant_scheme
quant_method = CompressedTensorsLinearMethod(self)
# choose transform method
if any((input_tfms, output_tfms)):
return CompressedTensorsLinearTransformMethod.from_schemes(
quant_method, quant_scheme, input_tfms, output_tfms)
else:
return quant_method
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE):
return CompressedTensorsMoEMethod.get_moe_method(self, layer)
return None
@classmethod
def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
ignore: list[str] = cast(list[str], config.get("ignore", []))
quant_format = cast(str, config.get("format"))
target_scheme_map = cls._quantization_scheme_map_from_config(
config=config)
sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
config=config)
transform_config = config.get("transform_config")
return cls(
target_scheme_map=target_scheme_map,
ignore=ignore,
quant_format=quant_format,
sparsity_scheme_map=sparsity_scheme_map,
sparsity_ignore_list=sparsity_ignore_list,
config=config,
transform_config=transform_config,
)
@classmethod
def _parse_sparsity_config(
cls, config: dict[str, Any]
) -> tuple[dict[str, SparsityCompressionConfig], list[str]]:
"""
:param config: The `quantization_config` dictionary from config.json
:return: A tuple with two elements
1. A dictionary mapping target layer names to their corresponding
sparsity_config
2. A list of layer names to ignore for sparsity
"""
if not (sparsity_config := config.get(SPARSITY_CONFIG_NAME)):
return dict(), []
sparsity_config = SparsityCompressionConfig.model_validate(
sparsity_config)
sparse_scheme_map: dict[str, SparsityCompressionConfig] = {
target: sparsity_config
for target in sparsity_config.targets or list()
}
sparsity_ignore_list = sparsity_config.ignore or list()
return sparse_scheme_map, sparsity_ignore_list
@classmethod
def _quantization_scheme_map_from_config(
cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
"""
:param config: The `quantization_config` dictionary from config.json
:return: A dictionary mapping target layer names to their corresponding
quantization_args for weights and input activations
"""
target_scheme_map: dict[str, Any] = dict()
quant_format = cast(str, config.get("format"))
# The quant_config has multiple config_groups, each containing
# an input_activations key with details about how the activations are
# quantized, a weights key indicating how the weights are quantized,
# and a list of targets under the `targets` key, dictating which
# layers are impacted by the quantization details. The quantization
# details follow the structure defined by the QuantizationArgs
# pydantic model, which is used to verify the structure of the
# quant_config and also store the details for later use.
config_groups = config.get("config_groups", dict())
for _, quant_config in config_groups.items():
targets = quant_config.get("targets")
for target in targets:
target_scheme_map[target] = {}
target_scheme_map[target][
"weights"] = QuantizationArgs.model_validate(
quant_config.get("weights"))
target_scheme_map[target]["input_activations"] = None
target_scheme_map[target]["format"] = quant_config.get(
"format")
format = target_scheme_map[target].get("format")
# If no per-config format defined, use global format in config
act_quant_format = is_activation_quantization_format(
format
) if format is not None else is_activation_quantization_format(
quant_format)
# TODO(czhu): w4a8fp8 is in packed-quantized format
# but needs input activation quantization
input_activations = quant_config.get("input_activations")
if act_quant_format or input_activations:
# The only case where we have activation quant supported
# but no input_activations provided in the config
# should be w8a16fp8 w8a16fp8 can also run for cases where
# there is an input_quant but it is ignored
if not input_activations:
assert target_scheme_map[target][
"weights"].type == QuantizationType.FLOAT
else:
target_scheme_map[target][
"input_activations"] = QuantizationArgs.model_validate( # noqa: E501
quant_config.get("input_activations"))
return target_scheme_map
@classmethod
def get_config_filenames(cls) -> list[str]:
return []
def _check_scheme_supported(self,
min_capability: int,
error: bool = True,
match_exact: bool = False) -> bool:
capability_tuple = current_platform.get_device_capability()
if capability_tuple is not None:
capability = capability_tuple.to_int()
if match_exact:
supported = capability == min_capability
if error and not supported:
raise RuntimeError(
"Quantization scheme is not supported for ",
"the current GPU. Required capability: ",
f"{min_capability}. Current capability: {capability}.")
else:
supported = capability >= min_capability
if error and not supported:
raise RuntimeError(
"Quantization scheme is not supported for ",
f"the current GPU. Min capability: {min_capability}. ",
f"Current capability: {capability}.")
return supported
else:
return False
def _is_fp4a4_nvfp4(self, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs):
if weight_quant is None or input_quant is None:
return False
is_tensor_group_quant = (weight_quant.strategy
== QuantizationStrategy.TENSOR_GROUP.value
and input_quant.strategy
== QuantizationStrategy.TENSOR_GROUP.value)
is_symmetric = weight_quant.symmetric and input_quant.symmetric
is_group_size_16 = (weight_quant.group_size == 16
and input_quant.group_size == 16)
is_float_type = (weight_quant.type == QuantizationType.FLOAT
and input_quant.type == QuantizationType.FLOAT.value)
is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4
return (is_tensor_group_quant and is_float_type and is_4_bits
and is_group_size_16 and is_symmetric)
def _is_fp4a16_nvfp4(self, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs):
is_weight_only = weight_quant is not None and input_quant is None
is_tensor_group_quant = (
weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value)
is_symmetric = weight_quant.symmetric
is_group_size_16 = weight_quant.group_size == 16
is_float_type = weight_quant.type == QuantizationType.FLOAT
is_4_bits = weight_quant.num_bits == 4
return (is_weight_only and is_tensor_group_quant and is_float_type
and is_4_bits and is_group_size_16 and is_symmetric)
def _is_static_tensor_w8a8(self, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.TENSOR.value
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
is_tensor = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TENSOR.value)
is_static = not weight_quant.dynamic and not input_quant.dynamic
# Both symmetric and asymmetric input quantization supported.
# Only symmetric weight quantization supported.
return is_8_bits and is_tensor and weight_quant.symmetric and is_static
def _is_dynamic_token_w8a8(self, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.TENSOR.value
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
is_token = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TOKEN.value)
is_dynamic = not weight_quant.dynamic and input_quant.dynamic
# Both symmetric and asymmetric input quantization supported.
# Only symmetric weight quantization supported.
return is_8_bits and is_token and weight_quant.symmetric and is_dynamic
def _is_dynamic_token_w4a8_int(self, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs) -> bool:
is_weight_4_bits = weight_quant.num_bits == 4
is_activation_8_bits = input_quant.num_bits == 8
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.GROUP.value
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
is_token = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TOKEN.value)
is_dynamic = not weight_quant.dynamic and input_quant.dynamic
# Both symmetric and asymmetric input quantization supported.
# Only symmetric weight quantization supported.
return (is_weight_4_bits and is_activation_8_bits and is_token
and weight_quant.symmetric and is_dynamic)
def _is_fp8_w8a8(self, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs) -> bool:
# Confirm weights and activations quantized.
if weight_quant is None or input_quant is None:
return False
# Confirm weight scheme is supported.
is_floating_point = (weight_quant.type == QuantizationType.FLOAT
and input_quant.type == QuantizationType.FLOAT)
is_symmetric_weight = weight_quant.symmetric
is_static_weight = not weight_quant.dynamic
is_tensor_or_channel_or_block_weight = (weight_quant.strategy in [
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL,
QuantizationStrategy.BLOCK
])
if not (is_floating_point and is_symmetric_weight and is_static_weight
and is_tensor_or_channel_or_block_weight):
return False
# Dynamic quantization is always supported if weights supported.
if input_quant.dynamic:
return True
# Confirm activation scheme is supported.
is_symmetric_activation = input_quant.symmetric
is_per_tensor_activation = (
input_quant.strategy == QuantizationStrategy.TENSOR)
return is_symmetric_activation and is_per_tensor_activation
def _is_fp8_w4a8(self, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs) -> bool:
if not weight_quant or not input_quant:
return False
is_weight_4_bits = weight_quant.num_bits == 4
is_activation_8_bits = input_quant.num_bits == 8
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.GROUP.value)
is_token = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TOKEN.value)
is_dynamic = not weight_quant.dynamic and input_quant.dynamic
is_symmetric = weight_quant.symmetric and input_quant.symmetric
# Only per-group symmetric weight (4bit)
# + per-tok symmetric activation (8bit) quantization supported.
return (is_weight_4_bits and is_activation_8_bits and is_token
and is_symmetric and is_dynamic)
def _is_fp8_w4a8_sm90(self, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs) -> bool:
return (self._check_scheme_supported(90, error=False, match_exact=True)
and self._is_fp8_w4a8(weight_quant, input_quant))
def _is_fp8_w8a8_sm90(self, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs) -> bool:
return (self._check_scheme_supported(90, error=False, match_exact=True)
and self._is_fp8_w8a8(weight_quant, input_quant))
def _is_fp8_w8a8_sm100(self, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs) -> bool:
return (self._check_scheme_supported(
100, error=False, match_exact=True)
and self._is_fp8_w8a8(weight_quant, input_quant))
def _is_fp8_w8a16(self, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs) -> bool:
# Confirm weights quantized.
if weight_quant is None:
return False
# Confirm we have floating points.
if weight_quant.type != QuantizationType.FLOAT:
return False
# Confirm weight scheme is supported.
is_symmetric_weight = weight_quant.symmetric
is_static_weight = not weight_quant.dynamic
is_tensor_or_channel_or_block_weight = (weight_quant.strategy in [
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL,
QuantizationStrategy.BLOCK
])
if not (is_symmetric_weight and is_static_weight # noqa: SIM103
and is_tensor_or_channel_or_block_weight):
return False
# All conditions satisfied.
return True
def _is_wNa16_group_channel(self, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs) -> bool:
input_quant_none = input_quant is None
is_channel_group = (
weight_quant.strategy == QuantizationStrategy.CHANNEL.value
or weight_quant.strategy == QuantizationStrategy.GROUP.value)
is_static = not weight_quant.dynamic
return (is_channel_group and input_quant_none and is_static)
def _get_scheme_from_parts(
self,
weight_quant: QuantizationArgs,
input_quant: QuantizationArgs,
format: Optional[str] = None) -> "CompressedTensorsScheme":
# use the per-layer format if defined, otherwise, use global format
format = format if format is not None else self.quant_format
# Detect If Mixed Precision
if self._is_fp4a16_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A16Fp4()
if self._is_fp8_w4a8_sm90(weight_quant, input_quant):
return CompressedTensorsW4A8Fp8(num_bits=weight_quant.num_bits,
strategy=weight_quant.strategy,
symmetric=weight_quant.symmetric,
group_size=weight_quant.group_size,
actorder=weight_quant.actorder)
if self._is_wNa16_group_channel(weight_quant, input_quant):
if (format == CompressionFormat.marlin_24.value
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
assert weight_quant.symmetric
return CompressedTensorsW4A16Sparse24(
strategy=weight_quant.strategy,
num_bits=weight_quant.num_bits,
group_size=weight_quant.group_size)
if (format == CompressionFormat.pack_quantized.value
and weight_quant.num_bits in WNA16_SUPPORTED_BITS):
return CompressedTensorsWNA16(
num_bits=weight_quant.num_bits,
strategy=weight_quant.strategy,
symmetric=weight_quant.symmetric,
group_size=weight_quant.group_size,
actorder=weight_quant.actorder)
act_quant_format = is_activation_quantization_format(format)
if act_quant_format:
if self._is_fp4a4_nvfp4(weight_quant, input_quant):
if cutlass_fp4_supported(
) or envs.VLLM_USE_NVFP4_CT_EMULATIONS:
return CompressedTensorsW4A4Fp4()
else:
logger.warning_once(
"Current platform does not support cutlass NVFP4."
" Running CompressedTensorsW4A16Fp4.")
return CompressedTensorsW4A16Fp4(
has_input_global_scale=True)
if self._is_fp8_w8a8(weight_quant, input_quant):
is_fp8_w8a8_supported = self._check_scheme_supported(
CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
if is_fp8_w8a8_supported:
return CompressedTensorsW8A8Fp8(
weight_quant=weight_quant,
is_static_input_scheme=(input_quant
and not input_quant.dynamic))
else:
# note: input_quant will be present for converted models;
# will be ignored during inference post loading
return CompressedTensorsW8A16Fp8(
strategy=weight_quant.strategy,
is_static_input_scheme=not input_quant.dynamic)
# note: input_quant can be None
if self._is_fp8_w8a16(weight_quant, input_quant):
is_static_input_scheme = (input_quant
and not input_quant.dynamic)
return CompressedTensorsW8A16Fp8(
strategy=weight_quant.strategy,
is_static_input_scheme=is_static_input_scheme)
if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8(
strategy=weight_quant.strategy,
is_static_input_scheme=True,
input_symmetric=input_quant.symmetric)
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8(
strategy=weight_quant.strategy,
is_static_input_scheme=False,
input_symmetric=input_quant.symmetric)
if self._is_dynamic_token_w4a8_int(weight_quant, input_quant):
is_static_input_scheme = (input_quant
and not input_quant.dynamic)
return CompressedTensorsW4A8Int(
num_bits=weight_quant.num_bits,
strategy=weight_quant.strategy,
group_size=weight_quant.group_size,
is_static_input_scheme=is_static_input_scheme,
input_symmetric=input_quant.symmetric)
raise NotImplementedError(
"No compressed-tensors compatible scheme was found.")
def get_scheme(self,
layer: torch.nn.Module,
layer_name: Optional[str] = None
) -> Optional["CompressedTensorsScheme"]:
"""
compressed-tensors supports non uniform in the following way:
targets of config_groups: There can be N config_groups which each
have a quantization scheme. Each config_group has a list of targets
which can be a full layer_name, a regex for a layer_name, or
an nn.Module name.
Detect whether a layer_name is found in any target and
use the quantization scheme corresponding to the matched target
to select the CompressedTensorsScheme used for inference.
"""
# Find the "target" in the compressed-tensors config
# that our layer conforms to.
# TODO (@kylesayrs): support ignore module names with ct matching utils
if should_ignore_layer(layer_name,
ignore=self.ignore,
fused_mapping=self.packed_modules_mapping):
return None
# Will be empty for models with only sparsity
weight_quant = input_quant = None
if self.target_scheme_map:
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.target_scheme_map.keys(),
fused_mapping=self.packed_modules_mapping)
scheme_dict = self.target_scheme_map[matched_target]
weight_quant = scheme_dict.get("weights")
input_quant = scheme_dict.get("input_activations")
format = scheme_dict.get("format")
# Find the sparsity scheme of the layer
# assume that fused layers inherit first component's sparsity scheme
sparsity_targets = (self.sparsity_scheme_map.keys() -
set(self.sparsity_ignore_list))
sparsity_scheme: Optional[SparsityCompressionConfig] = None
with suppress(ValueError):
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=sparsity_targets,
fused_mapping=self.packed_modules_mapping)
sparsity_scheme = self.sparsity_scheme_map[matched_target]
if self.supports_cutlass_24(weight_quant=weight_quant,
input_quant=input_quant,
sparsity_scheme=sparsity_scheme):
# Have a valid sparsity scheme
# Validate layer is supported by Cutlass 2:4 Kernel
model_compression_config = (None if sparsity_scheme is None
or sparsity_scheme.format == "dense"
else self.config)
scheme = CompressedTensors24(
quantized=weight_quant is not None or input_quant is not None,
weight_quant=weight_quant,
input_quant=input_quant,
model_compression_config=model_compression_config,
)
elif weight_quant is None:
logger.warning_once("Acceleration for non-quantized schemes is "
"not supported by Compressed Tensors. "
"Falling back to UnquantizedLinearMethod")
return None
else:
# Find the quant_scheme
scheme = self._get_scheme_from_parts( # type: ignore
weight_quant=weight_quant,
input_quant=input_quant,
format=format)
# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
self._check_scheme_supported(scheme.get_min_capability())
logger.debug("Using scheme: %s for %s", scheme.__class__.__name__,
layer_name)
return scheme
def get_cache_scale(self, name: str) -> Optional[str]:
"""
Check whether the param name matches the format for k/v cache scales
in compressed-tensors. If this is the case, return its equivalent
param name expected by vLLM
:param name: param name
:return: matching param name for KV cache scale in vLLM
"""
if name.endswith(".output_scale") and ".k_proj" in name:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
# If no matches, return None
return None
@staticmethod
def supports_cutlass_24(
weight_quant: Optional[QuantizationArgs],
input_quant: Optional[QuantizationArgs],
sparsity_scheme: Optional[SparsityCompressionConfig] = None
) -> bool:
"""
Check if the layer is supported by the Cutlass 2:4 Kernel
Conditions:
- Overarching condition: Sparsity Structure is 2:4
- Unquantized cases are supported
- Weight only quantization is not-supported
- Supported weight quantization strategies are TENSOR and CHANNEL
- Supported input quantization strategies are TENSOR and TOKEN
- Only 8 bit quantization is supported
:return: True if the layer is supported by the Cutlass 2:4 Kernel
False otherwise
"""
if sparsity_scheme is None:
return False
is_valid_sparsity_structure: bool = (
sparsity_scheme.sparsity_structure ==
SparsityStructure.TWO_FOUR.value)
valid_compressors = {
CompressionFormat.dense.value,
CompressionFormat.sparse_24_bitmask.value
}
is_valid_sparsity = (is_valid_sparsity_structure
and sparsity_scheme.format in valid_compressors)
if not is_valid_sparsity:
return False
# Unquantized cases are supported
if weight_quant is None and input_quant is None:
return True
# Weight only quantization is not-supported
if weight_quant is not None and input_quant is None:
return False
supported_weight_quant_strategies = [
QuantizationStrategy.TENSOR.value,
QuantizationStrategy.CHANNEL.value
]
assert weight_quant is not None
assert input_quant is not None
if weight_quant.strategy not in supported_weight_quant_strategies:
return False
supported_input_quant_strategies = [
QuantizationStrategy.TENSOR.value, QuantizationStrategy.TOKEN.value
]
if input_quant.strategy not in supported_input_quant_strategies:
return False
return weight_quant.num_bits == input_quant.num_bits == 8
class CompressedTensorsLinearMethod(LinearMethodBase):
def __init__(self, quantization_config: CompressedTensorsConfig):
self.quantization_config = quantization_config
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.scheme.process_weights_after_loading(layer)
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""
Use the CompressedTensorsScheme associated with each layer to create
the necessary parameters for the layer. See LinearMethodBase for param
details
"""
weight_loader = extra_weight_attrs.get("weight_loader")
layer.scheme.create_weights(
layer=layer,
input_size=input_size,
input_size_per_partition=input_size_per_partition,
output_partition_sizes=output_partition_sizes,
output_size=output_size,
params_dtype=params_dtype,
weight_loader=weight_loader)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None):
"""
Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the
layer input. See LinearMethodBase for param details
"""
scheme = layer.scheme
if scheme is None:
raise ValueError("A scheme must be defined for each layer")
return scheme.apply_weights(layer, x, bias=bias)
class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from compressed-tensors
checkpoints.
"""
def __init__(self, quant_config: CompressedTensorsConfig):
self.validate_kv_cache_scheme(quant_config.kv_cache_scheme)
super().__init__(quant_config)
@staticmethod
def validate_kv_cache_scheme(kv_cache_scheme: Optional[dict[str, Any]]):
"""
Validator for the kv cache scheme. Useful for controlling the
kv cache quantization schemes, that are being supported in vLLM
:param kv_cache_scheme: the compressed-tensors kv cache scheme
"""
if kv_cache_scheme is None:
return
type_ = kv_cache_scheme.get("type")
num_bits = kv_cache_scheme.get("num_bits")
if type_ != "float" and num_bits != 8:
raise NotImplementedError(
"Currently supported kv cache quantization is "
"num_bits=8, type=float, however "
f"received num_bits={num_bits}, type={type_}")
strategy = kv_cache_scheme.get("strategy")
if strategy != "tensor":
raise NotImplementedError(
"Only support per-tensor scaling factor "
"for compressed-tensors KV cache. "
f"Expected strategy: tensor, found strategy: {strategy}")
is_symmetric = kv_cache_scheme.get("symmetric")
if not is_symmetric:
raise NotImplementedError(
"Only support symmetric scaling factor "
"for compressed-tensors KV cache. "
f"However found symmetric: {is_symmetric}")

View File

@@ -0,0 +1,27 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .compressed_tensors_scheme import CompressedTensorsScheme
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8
from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
CompressedTensorsW4A16Sparse24)
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
CompressedTensorsWNA16)
from .compressed_tensors_24 import CompressedTensors24 # isort: skip
__all__ = [
"CompressedTensorsScheme", "CompressedTensorsWNA16",
"CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24",
"CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
"WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
"CompressedTensors24", "CompressedTensorsW4A16Fp4",
"CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int",
"CompressedTensorsW4A8Fp8"
]

View File

@@ -0,0 +1,366 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional
import torch
from compressed_tensors import CompressionFormat, ModelCompressor
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
from compressed_tensors.utils import combine_shards
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise, sparse_cutlass_supported)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
__all__ = ["CompressedTensors24"]
from vllm.platforms import current_platform
class CompressedTensors24(CompressedTensorsScheme):
def __init__(
self,
quantized: bool = False,
weight_quant: Optional[QuantizationArgs] = None,
input_quant: Optional[QuantizationArgs] = None,
model_compression_config: Optional[dict[str, Any]] = None,
):
self.quantized = quantized
self.weight_quant = weight_quant
self.input_quant = input_quant
self.model_compressor = (
ModelCompressor.from_compression_config(model_compression_config)
if model_compression_config is not None else None)
self.do_sparse_decompress = (
self.model_compressor is not None
and self.model_compressor.sparsity_config.format
== CompressionFormat.sparse_24_bitmask.value)
if quantized and input_quant is not None and \
self._get_quant_dtype() == current_platform.fp8_dtype():
static = not input_quant.dynamic
g_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
self.quant_fp8 = QuantFP8(static, g_shape)
@classmethod
def get_min_capability(cls) -> int:
# Only cutlass 3.x kernels are implemented so far
return 90
def create_weights(
self,
layer: torch.nn.Module,
input_size: int,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype,
weight_loader: Callable,
**kwargs,
):
if not sparse_cutlass_supported():
raise ValueError(
"Sparse CUTLASS not supported. vLLM must be built with "
"CUDA 12.2 or later to use this feature")
layer.logical_widths = output_partition_sizes
layer.input_size = input_size
layer.input_size_per_partition = input_size_per_partition
self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)
# parameter to store uncompressed weight
weight = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=self.weights_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
if self.do_sparse_decompress:
assert all(partition_size % 8 == 0
for partition_size in output_partition_sizes
), "All partitions must be divisible by 8 for "
"2:4 sparse compressed models"
shape = BasevLLMParameter(
data=torch.empty(2, 1, dtype=torch.int64),
weight_loader=weight_loader,
)
compressed_weight = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // 2,
dtype=self.weights_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
bitmask = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // 8,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("shape", shape)
layer.register_parameter("compressed", compressed_weight)
layer.register_parameter("bitmask", bitmask)
# Check if quantized, not just 2:4 Sparse
if self.quantized:
if (self.weight_quant and self.weight_quant.strategy
== QuantizationStrategy.CHANNEL.value):
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader,
)
else:
assert (self.weight_quant and self.weight_quant.strategy
== QuantizationStrategy.TENSOR.value)
weight_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
# input quant will be non-none
if self.input_quant and not self.input_quant.dynamic:
# register input quant scale
assert (self.input_quant.strategy ==
QuantizationStrategy.TENSOR.value)
input_scale = BasevLLMParameter(
data=torch.empty(1, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("input_scale", input_scale)
else:
# for sparse-only, pass in 1 for weight/input scales
weight_scale = torch.nn.Parameter(data=torch.ones(
1, dtype=torch.float32),
requires_grad=False)
input_scale = torch.nn.Parameter(data=torch.ones(
1, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("input_scale", input_scale)
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight", weight)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""
Compress weights after loading. Store compressed weight and meta
tensor
:post-condition: layer.w_compressed and layer.meta are
set to the compressed weight and meta tensor in the
format expected by the Cutlass kernels
:param layer: The layer with the weights to be processed
"""
if self.do_sparse_decompress:
layer.weight.data = self._decompress_bitmask_compressed_weight(
compressed=layer.compressed,
bitmask=layer.bitmask,
layer=layer,
)
# compressed and bitmask tensors
# are no longer needed after decompression
del layer.compressed
del layer.bitmask
# torch.compile workaround
if hasattr(layer, "input_scale"):
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
requires_grad=False)
if self.weight_quant:
if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
layer.weight_scale = torch.nn.Parameter(
convert_to_channelwise(
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths,
),
requires_grad=False,
)
else:
# torch.compile workaround
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data, requires_grad=False)
# Set all negative zero values to 0 prior to compression
if (layer.weight.dtype.is_floating_point
and layer.weight.dtype.itemsize >= 2):
layer.weight.data[layer.weight.data == -0.0] = 0.0
w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data)
layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
layer.meta = torch.nn.Parameter(meta, requires_grad=False)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Returns the output tensor for the layer with 2:4
sparse compressed weights, given the input tensor
and bias
:param layer: The layer with 2:4 sparse compressed
weights to be used for the computation
:param x: The input tensor to the layer
:param bias: The bias to be added to the output tensor
:return: The output tensor of the layer
"""
if self.quantized:
scale = getattr(layer, 'input_scale', None)
if self.weights_dtype == torch.int8:
ops_output = ops.scaled_int8_quant(x, scale=scale)
q_input = ops_output[0]
input_scale = ops_output[1]
else:
assert self.weights_dtype == torch.float8_e4m3fn
q_input, input_scale = self.quant_fp8(x, scale=scale)
else:
# Not quantized, nothing to do with the input_scales, use as is
input_scale = layer.input_scale
q_input = x
out = ops.cutlass_scaled_sparse_mm(
a=q_input,
bt_nzs=layer.weight,
bt_meta=layer.meta,
scale_a=input_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias,
)
assert out.is_contiguous()
return out
def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
if not self.quantized:
return params_dtype
return self._get_quant_dtype()
def _get_quant_dtype(self) -> torch.dtype:
assert self.quantized
assert self.weight_quant is not None
assert self.input_quant is not None
is_8_bits = self.weight_quant.num_bits == self.input_quant.num_bits == 8
if not is_8_bits:
raise ValueError("Cutlass only supports 8-bit quantization")
if (self.weight_quant.type == QuantizationType.FLOAT
and self.input_quant.type == QuantizationType.FLOAT):
return torch.float8_e4m3fn
if (self.weight_quant.type == QuantizationType.INT
and self.input_quant.type == QuantizationType.INT):
return torch.int8
raise ValueError("Quantization type not supported by Cutlass")
def _decompress_bitmask_compressed_weight(
self,
compressed: torch.Tensor,
bitmask: torch.Tensor,
layer: torch.nn.Module,
) -> torch.Tensor:
"""
Decompress a compressed 2:4 sparse weight tensor using the bitmask and
return the result.
This function also supports sharded decompression.
:param compressed: The 2:4 sparse weight tensor compressed using the
sparse-24-bitmask compressor. This is different from
`cutlass_sparse_compress` which uses a different scheme (2 bits for
every nonzero element that represent the coordinate within the block
of 4). The bitmask compression here uses a bitmask to indicate the
positions of non-zero elements.
:param bitmask: The 2:4 bitmask associated with the compressed weights,
representing the positions of non-zero elements in the compressed
tensor.
:param layer: The layer whose weights need to be processed after
loading.
:return: The decompressed 2:4 sparse weight tensor.
"""
sparsity_compressor = self.model_compressor.sparsity_compressor
def _process_split(
bitmask_compressed_weight: torch.Tensor,
shape,
bitmask: torch.Tensor,
) -> torch.Tensor:
weight_data = dict(
compressed=bitmask_compressed_weight,
shape=shape,
bitmask=bitmask,
)
return sparsity_compressor.decompress_weight(weight_data)
split_weights: list[torch.Tensor] = []
split_bitmask: list[torch.Tensor] = []
split_shape: list[tuple[int, int]] = []
if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
split_weights = torch.split(compressed, layer.logical_widths)
split_bitmask = torch.split(bitmask, layer.logical_widths)
split_shape = [(out, layer.input_size_per_partition)
for out in layer.logical_widths]
if split_weights:
decompressed_shards = [
_process_split(compressed_weight, shape, bitmask)
for compressed_weight, shape, bitmask in zip(
split_weights, split_shape, split_bitmask)
]
decompressed = combine_shards(decompressed_shards)
else:
decompressed = sparsity_compressor.decompress_weight(
dict(
compressed=compressed,
shape=(
layer.logical_widths[0],
layer.input_size_per_partition,
),
bitmask=bitmask,
))
return decompressed

View File

@@ -0,0 +1,55 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Optional
import torch
__all__ = ["CompressedTensorsScheme"]
class CompressedTensorsScheme(ABC):
"""
Abstract class used to describe the weight creation and forward pass
of different quantization schemes supported by CompressedTensors.
"""
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
"""
Get minimum device capability.
"""
raise NotImplementedError
@abstractmethod
def create_weights(self, *args, **kwargs):
"""
Weight creation for the particular scheme. Inputs to this function
"""
raise NotImplementedError
@abstractmethod
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]):
"""
Run the forward pass for the particular scheme. This is where
scheme-specific dequant/quant steps/kernels should be applied.
:param layer: torch.nn.Module with the registered weights and
other parameters relevant to the particular scheme.
:param x: input to the layer
:param bias: bias parameter
"""
raise NotImplementedError
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module):
"""
Called after weight loading is complete for any cleanup that
needs to occur.
"""
raise NotImplementedError

View File

@@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
import torch
from torch.nn import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.scalar_type import scalar_types
__all__ = ["CompressedTensorsW4A16Sparse24"]
W4A16SPARSE24_SUPPORTED_TYPES_MAP = {
4: scalar_types.uint4b8,
}
W4A16SPARSE24_SUPPORTED_BITS = list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
def __init__(self,
strategy: str,
num_bits: int,
group_size: Optional[int] = None):
self.strategy = strategy
self.group_size = group_size
self.tile_size = 16
if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP:
raise ValueError(
f"Unsupported num_bits = {num_bits}. "
f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}")
self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits]
if self.strategy == "group" and self.group_size is None:
raise ValueError(
"group_size must be given when using strategy group")
@classmethod
def get_min_capability(cls) -> int:
# ampere + up
return 80
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# required by torch.compile to be torch.nn.Parameter
layer.weight_packed = Parameter(layer.weight_packed.data,
requires_grad=False)
layer.scale_packed = Parameter(layer.scale_packed.data,
requires_grad=False)
layer.meta = Parameter(layer.meta.data, requires_grad=False)
def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
assert params_dtype == torch.float16, (
"float16 is required for marlin24 compressed models. Set dtype=torch.float16" # noqa: E501
)
pack_factor = 32 // self.quant_type.size_bits
output_size_per_partition = sum(output_partition_sizes)
qweight = PackedvLLMParameter(data=torch.empty(
input_size_per_partition // self.tile_size // 2,
output_size_per_partition * self.tile_size // pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=pack_factor,
marlin_tile_size=self.tile_size,
weight_loader=weight_loader)
input_groups = (1 if self.group_size is None else
input_size_per_partition // self.group_size)
weight_scale_args = {
"data":
torch.empty(
input_groups,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if self.group_size is not None:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
else:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
weight_shape = BasevLLMParameter(data=torch.empty(2,
dtype=torch.int64),
weight_loader=weight_loader)
meta = PackedvLLMParameter(data=torch.empty(
input_size_per_partition // 8 // 2 // 2,
output_size_per_partition * 2,
dtype=torch.int16,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=1,
marlin_tile_size=2,
weight_loader=weight_loader)
layer.register_parameter("weight_packed", qweight)
layer.register_parameter("weight_shape", weight_shape)
layer.register_parameter("scale_packed", scales)
layer.register_parameter("meta", meta)
max_workspace_size = (
output_size_per_partition //
GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL
workspace = Parameter(torch.zeros(max_workspace_size, dtype=torch.int),
requires_grad=False)
layer.workspace = workspace
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
qweight = layer.weight_packed
meta = layer.meta
scales = layer.scale_packed
workspace = layer.workspace
x_2d = x.view(-1, x.shape[-1])
size_m = x_2d.shape[0]
size_k = x_2d.shape[1]
size_n = scales.shape[1]
output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
workspace, self.quant_type, size_m,
size_n, size_k)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
if bias is not None:
output.add_(bias) # In-place add
return output

View File

@@ -0,0 +1,105 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear, prepare_fp4_layer_for_marlin)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
__all__ = ["CompressedTensorsW4A16Fp4"]
class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
def __init__(self, has_input_global_scale: bool = False):
self.has_input_global_scale = has_input_global_scale
self.group_size = 16
@classmethod
def get_min_capability(cls) -> int:
# dont restrict as emulations
return 80
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
# Weight
weight = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // 2,
dtype=torch.uint8),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight_packed", weight)
# Global Weight Scale
weight_global_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("weight_global_scale", weight_global_scale)
# Per Group Weight Scale
weight_scale = GroupQuantScaleParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // self.group_size,
dtype=torch.float8_e4m3fn,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight_scale", weight_scale)
if self.has_input_global_scale:
input_global_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("input_global_scale", input_global_scale)
def process_weights_after_loading(self, layer) -> None:
# Process parameters for marlin repacking
# Rename weight_packed to weight that marlin expects
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
del layer.weight_packed
# Rename weight_global_scale to weight_scale_2 that marlin expects
# Note: ct stores the inverse of what is expected by the marlin kernel
layer.weight_scale_2 = Parameter(
1 / layer.weight_global_scale.max().to(torch.float32),
requires_grad=False)
del layer.weight_global_scale
if self.has_input_global_scale:
layer.input_global_scale = torch.nn.Parameter(
layer.input_global_scale.data, requires_grad=False)
prepare_fp4_layer_for_marlin(layer)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return apply_fp4_marlin_linear(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
weight_scale_2=layer.weight_scale_2,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias)

View File

@@ -0,0 +1,185 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
import torch
from torch.nn.parameter import Parameter
import vllm.envs as envs
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
run_nvfp4_emulations)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
swizzle_blockscale)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer
logger = init_logger(__name__)
__all__ = ["CompressedTensorsW4A4Fp4"]
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
def __init__(self):
if envs.VLLM_USE_TRTLLM_FP4_GEMM:
assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
self.backend = "flashinfer-trtllm"
logger.info_once("Using flashinfer-trtllm for FP4")
elif envs.VLLM_USE_FBGEMM:
self.backend = "fbgemm"
try:
import fbgemm_gpu # noqa: F401
except ImportError as exc:
raise ImportError(
"Backend fbgemm requires fbgemm.f4f4bf16 operator, "
"Please install with: pip install fbgemm-gpu-genai"
) from exc
logger.info_once("Using FGBEMM-GPU-GENAI for FP4")
elif has_flashinfer():
self.backend = "flashinfer-cutlass"
logger.info_once("Using flashinfer-cutlass for FP4")
else:
self.backend = "cutlass"
self.group_size = 16
@classmethod
def get_min_capability(cls) -> int:
if envs.VLLM_USE_NVFP4_CT_EMULATIONS:
return 80
return 100
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
# Weight
weight = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // 2,
dtype=torch.uint8),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight_packed", weight)
# Global Weight Scale
weight_global_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("weight_global_scale", weight_global_scale)
# Per Group Weight Scale
weight_scale = GroupQuantScaleParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // self.group_size,
dtype=torch.float8_e4m3fn,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight_scale", weight_scale)
input_global_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("input_global_scale", input_global_scale)
def process_weights_after_loading(self, layer) -> None:
global_input_scale = layer.input_global_scale.max().to(torch.float32)
layer.input_global_scale = Parameter(global_input_scale,
requires_grad=False)
layer.weight_global_scale = Parameter(
layer.weight_global_scale.max().to(torch.float32),
requires_grad=False)
if self.backend == "flashinfer-trtllm":
# FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
# FlashInfer provides nvfp4_quantize to quantize + shuffle the
# layout but we use our own quantization so we have to call
# shuffles ourselves.
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
weight = layer.weight_packed.data
weight_scale = layer.weight_scale.data
epilogue_tile_m = 128
weight = shuffle_matrix_a(weight.view(torch.uint8),
epilogue_tile_m)
weight_scale = (shuffle_matrix_sf_a(weight_scale.view(
torch.uint8), epilogue_tile_m).reshape(
weight_scale.shape).view(torch.float8_e4m3fn))
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.weight_packed = Parameter(weight, requires_grad=False)
else:
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
if self.backend == "fbgemm":
swizzled_weight_scale = swizzled_weight_scale.view(-1).view(
torch.uint8)
layer.weight_scale = Parameter(swizzled_weight_scale,
requires_grad=False)
layer.weight_packed = Parameter(layer.weight_packed.data,
requires_grad=False)
layer.alpha = Parameter(
1 / (layer.input_global_scale * layer.weight_global_scale),
requires_grad=False)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if envs.VLLM_USE_NVFP4_CT_EMULATIONS:
out = run_nvfp4_emulations(
x=x,
input_global_scale=layer.input_global_scale,
weight=layer.weight_packed,
weight_scale_swizzled=layer.weight_scale,
weight_global_scale=layer.weight_global_scale)
if bias is not None:
out = out + bias
return out
output_dtype = x.dtype
output_shape = [x.shape[0], layer.weight_packed.shape[0]]
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
mm_args = (x_fp4, layer.weight_packed, x_blockscale,
layer.weight_scale, layer.alpha, output_dtype)
if self.backend == "flashinfer-trtllm":
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
elif self.backend == "flashinfer-cutlass":
out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
elif self.backend == "fbgemm":
out = torch.ops.fbgemm.f4f4bf16(
x_fp4,
layer.weight_packed,
x_blockscale.view(-1).view(torch.uint8),
layer.weight_scale,
layer.alpha,
use_mx=False,
).to(output_dtype)
else:
out = cutlass_scaled_fp4_mm(*mm_args)
if bias is not None:
out = out + bias
return out.view(*output_shape)

View File

@@ -0,0 +1,169 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
import torch
from compressed_tensors.quantization import ActivationOrdering
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
# yapf: enable
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
__all__ = ["CompressedTensorsW4A8Fp8"]
W4A8_SUPPORTED_TYPES_MAP = {
4: scalar_types.int4,
}
W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
_kernel_backends_being_used: set[str] = set()
def __init__(self,
strategy: str,
num_bits: int,
group_size: Optional[int] = None,
symmetric: Optional[bool] = True,
actorder: Optional[ActivationOrdering] = None):
self.pack_factor = 32 // num_bits
self.strategy = strategy
self.symmetric = symmetric
self.group_size = -1 if group_size is None else group_size
self.has_g_idx = actorder == ActivationOrdering.GROUP
if self.group_size != 128 or self.strategy != "group":
raise ValueError("W4A8 kernels require group quantization " \
"with group size 128")
if num_bits not in W4A8_SUPPORTED_TYPES_MAP:
raise ValueError(
f"Unsupported num_bits = {num_bits}. "
f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}")
self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits]
@classmethod
def get_min_capability(cls) -> int:
# hopper
return 90
def create_weights(self, layer: torch.nn.Module, output_size: int,
input_size: int, output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=\
(input_size_per_partition, output_size_per_partition),
weight_type=self.quant_type,
act_type=torch.float8_e4m3fn, # always use fp8(e4m3)
group_size=self.group_size,
zero_points=not self.symmetric,
has_g_idx=self.has_g_idx,
out_type=params_dtype
)
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for CompressedTensorsW4A8Fp8",
kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# If group_size is -1, we are in channelwise case.
group_size = self.group_size if self.group_size != -1 else input_size
row_parallel = (input_size != input_size_per_partition)
partition_scales = not marlin_repeat_scales_on_all_ranks(
self.has_g_idx, self.group_size, row_parallel)
scales_and_zp_size = input_size // group_size
if partition_scales:
assert input_size_per_partition % group_size == 0
scales_and_zp_size = input_size_per_partition // group_size
weight = PackedvLLMParameter(input_dim=1,
output_dim=0,
weight_loader=weight_loader,
packed_factor=self.pack_factor,
packed_dim=1,
data=torch.empty(
output_size_per_partition,
input_size_per_partition //
self.pack_factor,
dtype=torch.int32,
))
# TODO(czhu): allocate the packed fp8 scales memory here?
# the scales will be expanded by 8x via `cutlass_pack_scale_fp8`
weight_scale_args = {
"weight_loader":
weight_loader,
"data":
torch.empty(
output_size_per_partition,
scales_and_zp_size,
dtype=torch.float8_e4m3fn,
)
}
if not partition_scales:
weight_scale = ChannelQuantScaleParameter(output_dim=0,
**weight_scale_args)
else:
weight_scale = GroupQuantScaleParameter(output_dim=0,
input_dim=1,
**weight_scale_args)
# A 2D array defining the original shape of the weights
# before packing
weight_shape = BasevLLMParameter(data=torch.empty(2,
dtype=torch.int64),
weight_loader=weight_loader)
# per-channel scales
weight_chan_scale = ChannelQuantScaleParameter(
data=torch.empty((output_size_per_partition, 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight_packed", weight)
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_shape", weight_shape)
layer.register_parameter("weight_chan_scale", weight_chan_scale)
self.kernel = kernel_type(mp_linear_kernel_config,
w_q_param_name="weight_packed",
w_s_param_name="weight_scale",
w_zp_param_name="weight_zero_point",
w_gidx_param_name="weight_g_idx")
# Checkpoints are serialized in compressed-tensors format, which is
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)

View File

@@ -0,0 +1,135 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
ModelWeightParameter)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
__all__ = ["CompressedTensorsW4A8Int"]
W4A8_SUPPORTED_TYPES_MAP = {
4: scalar_types.int4,
}
W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsW4A8Int(CompressedTensorsScheme):
_kernel_backends_being_used: set[str] = set()
def __init__(self,
strategy: str,
num_bits: int,
group_size: Optional[int] = None,
is_static_input_scheme: bool = False,
input_symmetric: bool = True):
self.strategy = strategy
self.group_size = -1 if group_size is None else group_size
self.is_static_input_scheme = is_static_input_scheme
self.input_symmetric = input_symmetric
if num_bits not in W4A8_SUPPORTED_TYPES_MAP:
raise ValueError(
f"Unsupported num_bits = {num_bits}."
f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}")
self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits]
@classmethod
def get_min_capability(cls) -> int:
return 1
def create_weights(self, layer: torch.nn.Module, output_size: int,
input_size: int, output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
row_parallel = (input_size != input_size_per_partition)
# Compute effective group_size
if self.group_size == -1:
effective_group_size = (input_size_per_partition
if row_parallel else input_size)
else:
effective_group_size = self.group_size
# Ensure group_size divides input_size_per_partition
assert input_size_per_partition % effective_group_size == 0, (
f"input_size_per_partition {input_size_per_partition}"
f" not divisible by group_size {effective_group_size}")
# Determine scale partitioning
is_channelwise = (self.group_size == -1)
repeat_scales = (is_channelwise and row_parallel)
partition_scales = not repeat_scales
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=(input_size_per_partition,
output_size_per_partition),
weight_type=self.quant_type,
act_type=params_dtype,
group_size=effective_group_size,
zero_points=False,
has_g_idx=False,
)
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for CompressedTensorsW4A8Int",
kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
scales_and_zp_size = input_size_per_partition // effective_group_size
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.int8),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
weight_scale_args = {
"weight_loader":
weight_loader,
"data":
torch.empty(output_size_per_partition,
scales_and_zp_size,
dtype=params_dtype)
}
if partition_scales:
weight_scale = GroupQuantScaleParameter(output_dim=0,
input_dim=1,
**weight_scale_args)
else:
weight_scale = ChannelQuantScaleParameter(output_dim=0,
**weight_scale_args)
layer.register_parameter("weight_packed", weight)
layer.register_parameter("weight_scale", weight_scale)
self.kernel = kernel_type(mp_linear_kernel_config,
w_q_param_name="weight_packed",
w_s_param_name="weight_scale",
w_zp_param_name=None,
w_gidx_param_name=None)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)

View File

@@ -0,0 +1,121 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
import torch
from compressed_tensors.quantization import QuantizationStrategy
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
__all__ = ["CompressedTensorsW8A16Fp8"]
SUPPORTED_STRATEGIES = [
QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR
]
class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
def __init__(self, strategy: str, is_static_input_scheme: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
@classmethod
def get_min_capability(cls) -> int:
# ampere and up
return 80
# W8A8-Fp8 kernels support only per-tensor and per-channel cases.
# So if we have a fused module (QKV, MLP) with per tensor scales,
# we expand each scale to its shard's channels.
def process_weights_after_loading(self, layer) -> None:
if self.strategy == QuantizationStrategy.TENSOR:
ws_channelwise = convert_to_channelwise(layer.weight_scale,
layer.logical_widths)
layer.weight_scale = torch.nn.Parameter(ws_channelwise,
requires_grad=False)
else:
# required by torch.compile to be torch.nn.Parameter
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
requires_grad=False)
# Weights must be transposed for marlin
layer.weight = torch.nn.Parameter(layer.weight.t(),
requires_grad=False)
if self.is_static_input_scheme:
# required by torch.compile to be torch.nn.Parameter
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
requires_grad=False)
prepare_fp8_layer_for_marlin(layer)
def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
layer.weight_block_size = None
# WEIGHT
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
elif self.strategy == QuantizationStrategy.TENSOR:
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
else:
raise ValueError(
f"Unsupported weight strategy={self.strategy}, "
f"supported strategies are {SUPPORTED_STRATEGIES}")
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE (to deal with converted checkpoints)
if self.is_static_input_scheme:
input_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return apply_fp8_marlin_linear(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias)

View File

@@ -0,0 +1,157 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
import torch
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy)
from torch.nn import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_fp8_block_linear, check_aiter_fp8_linear_support,
create_fp8_input_scale, create_fp8_scale_parameter,
create_fp8_weight_parameter, maybe_post_process_fp8_weight_block,
process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy,
process_fp8_weight_tensor_strategy, validate_fp8_block_shape)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, cutlass_block_fp8_supported, maybe_create_device_identity)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
ChannelQuantScaleParameter,
PerTensorScaleParameter)
__all__ = ["CompressedTensorsW8A8Fp8"]
strategy_to_parameter_type = {
QuantizationStrategy.BLOCK: BlockQuantScaleParameter,
QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter,
QuantizationStrategy.TENSOR: PerTensorScaleParameter,
}
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def __init__(self, weight_quant: QuantizationArgs,
is_static_input_scheme: bool):
self.weight_quant = weight_quant
self.strategy = weight_quant.strategy
self.out_dtype = torch.get_default_dtype()
self.is_static_input_scheme = is_static_input_scheme
self.act_q_group_shape = GroupShape.PER_TENSOR \
if is_static_input_scheme else GroupShape.PER_TOKEN
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_q_group_shape)
self.weight_block_size = self.weight_quant.block_structure
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
@classmethod
def get_min_capability(cls) -> int:
# lovelace and up
return 89
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
weight_loader: Callable, **kwargs):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
layer.weight_block_size = None
if self.strategy == QuantizationStrategy.BLOCK:
assert self.weight_block_size is not None
layer.weight_block_size = self.weight_block_size
# Validate block quantization shapes
validate_fp8_block_shape(layer, input_size, output_size,
input_size_per_partition,
output_partition_sizes,
self.weight_block_size)
# WEIGHT
weight = create_fp8_weight_parameter(output_size_per_partition,
input_size_per_partition,
weight_loader)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
weight_scale = create_fp8_scale_parameter(
strategy_to_parameter_type[self.strategy], output_partition_sizes,
input_size_per_partition, layer.weight_block_size, weight_loader)
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
if self.is_static_input_scheme:
input_scale = create_fp8_input_scale(output_partition_sizes,
weight_loader)
layer.register_parameter("input_scale", input_scale)
def process_weights_after_loading(self, layer) -> None:
if self.strategy == QuantizationStrategy.TENSOR:
weight, weight_scale, input_scale = (
process_fp8_weight_tensor_strategy(
layer.weight, layer.weight_scale, layer.logical_widths,
getattr(layer, 'input_scale', None)))
weight = weight.t()
elif self.strategy == QuantizationStrategy.CHANNEL:
weight, weight_scale, input_scale = (
process_fp8_weight_channel_strategy(
layer.weight, layer.weight_scale,
getattr(layer, 'input_scale', None)))
weight = weight.t()
elif self.strategy == QuantizationStrategy.BLOCK:
assert self.is_static_input_scheme is False
weight, weight_scale = process_fp8_weight_block_strategy(
layer.weight, layer.weight_scale)
input_scale = None
else:
raise ValueError(f"Unknown quantization strategy {self.strategy}")
# required by torch.compile to be torch.nn.Parameter
layer.weight = Parameter(weight.data, requires_grad=False)
layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
if input_scale is not None:
layer.input_scale = Parameter(input_scale.data,
requires_grad=False)
# INPUT SCALE
if self.is_static_input_scheme and hasattr(layer, 'input_scale'):
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
else:
layer.input_scale = None
if self.strategy == QuantizationStrategy.BLOCK:
maybe_post_process_fp8_weight_block(
layer, self.cutlass_block_fp8_supported)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if layer.weight_block_size is not None:
return apply_fp8_block_linear(
layer,
input=x,
bias=bias,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported)
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias)

View File

@@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
import torch
from compressed_tensors.quantization import QuantizationStrategy
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
logger = init_logger(__name__)
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
_kernel_backends_being_used: set[str] = set()
def __init__(self, strategy: str, is_static_input_scheme: bool,
input_symmetric: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.input_symmetric = input_symmetric
@classmethod
def get_min_capability(cls) -> int:
# turing and up
return 75
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
layer.logical_widths = output_partition_sizes
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
is_static_input_scheme=self.is_static_input_scheme,
input_symmetric=self.input_symmetric)
kernel_type = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for CompressedTensorsW8A8Int8",
kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# WEIGHT
weight = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=torch.int8),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
else:
assert self.strategy == QuantizationStrategy.TENSOR
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
if self.is_static_input_scheme:
input_scale = BasevLLMParameter(data=torch.empty(
1, dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)
if not self.input_symmetric:
# Note: compressed-tensors stores the zp using the same dtype
# as the weights
# AZP loaded as int8 but used as int32
input_zero_point = BasevLLMParameter(
data=torch.empty(1, dtype=torch.int8),
weight_loader=weight_loader)
layer.register_parameter("input_zero_point", input_zero_point)
self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
w_q_param_name="weight",
w_s_param_name="weight_scale",
i_s_param_name="input_scale",
i_zp_param_name="input_zero_point",
azp_adj_param_name="azp_adj")
# Checkpoints are serialized in compressed-tensors format, which is
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)

View File

@@ -0,0 +1,201 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
import torch
from compressed_tensors.quantization import ActivationOrdering
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
# yapf: enable
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
__all__ = ["CompressedTensorsWNA16"]
WNA16_SUPPORTED_TYPES_MAP = {
4: scalar_types.uint4b8,
8: scalar_types.uint8b128
}
WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsWNA16(CompressedTensorsScheme):
_kernel_backends_being_used: set[str] = set()
def __init__(self,
strategy: str,
num_bits: int,
group_size: Optional[int] = None,
symmetric: Optional[bool] = True,
actorder: Optional[ActivationOrdering] = None):
self.pack_factor = 32 // num_bits
self.strategy = strategy
self.symmetric = symmetric
self.group_size = -1 if group_size is None else group_size
self.has_g_idx = actorder == ActivationOrdering.GROUP
if self.group_size == -1 and self.strategy != "channel":
raise ValueError("Marlin kernels require group quantization or "
"channelwise quantization, but found no group "
"size and strategy is not channelwise.")
if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
raise ValueError(
f"Unsupported num_bits = {num_bits}. "
f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")
self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
if not self.symmetric else
WNA16_SUPPORTED_TYPES_MAP[num_bits])
@classmethod
def get_min_capability(cls) -> int:
# ampere and up
return 80
def create_weights(self, layer: torch.nn.Module, output_size: int,
input_size: int, output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=\
(input_size_per_partition, output_size_per_partition),
weight_type=self.quant_type,
act_type=params_dtype,
group_size=self.group_size,
zero_points=not self.symmetric,
has_g_idx=self.has_g_idx
)
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for CompressedTensorsWNA16",
kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# If group_size is -1, we are in channelwise case.
group_size = self.group_size if self.group_size != -1 else input_size
row_parallel = (input_size != input_size_per_partition)
partition_scales = not marlin_repeat_scales_on_all_ranks(
self.has_g_idx, self.group_size, row_parallel)
scales_and_zp_size = input_size // group_size
if partition_scales:
assert input_size_per_partition % group_size == 0
scales_and_zp_size = input_size_per_partition // group_size
weight = PackedvLLMParameter(input_dim=1,
output_dim=0,
weight_loader=weight_loader,
packed_factor=self.pack_factor,
packed_dim=1,
data=torch.empty(
output_size_per_partition,
input_size_per_partition //
self.pack_factor,
dtype=torch.int32,
))
weight_scale_args = {
"weight_loader":
weight_loader,
"data":
torch.empty(
output_size_per_partition,
scales_and_zp_size,
dtype=params_dtype,
)
}
zeros_args = {
"weight_loader":
weight_loader,
"data":
torch.zeros(
output_size_per_partition // self.pack_factor,
scales_and_zp_size,
dtype=torch.int32,
)
}
if not partition_scales:
weight_scale = ChannelQuantScaleParameter(output_dim=0,
**weight_scale_args)
if not self.symmetric:
qzeros = PackedColumnParameter(output_dim=0,
packed_dim=0,
packed_factor=self.pack_factor,
**zeros_args)
else:
weight_scale = GroupQuantScaleParameter(output_dim=0,
input_dim=1,
**weight_scale_args)
if not self.symmetric:
qzeros = PackedvLLMParameter(input_dim=1,
output_dim=0,
packed_dim=0,
packed_factor=self.pack_factor,
**zeros_args)
# A 2D array defining the original shape of the weights
# before packing
weight_shape = BasevLLMParameter(data=torch.empty(2,
dtype=torch.int64),
weight_loader=weight_loader)
layer.register_parameter("weight_packed", weight)
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_shape", weight_shape)
if not self.symmetric:
layer.register_parameter("weight_zero_point", qzeros)
# group index (for activation reordering)
if self.has_g_idx:
weight_g_idx = RowvLLMParameter(data=torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight_g_idx", weight_g_idx)
self.kernel = kernel_type(mp_linear_kernel_config,
w_q_param_name="weight_packed",
w_s_param_name="weight_scale",
w_zp_param_name="weight_zero_point",
w_gidx_param_name="weight_g_idx")
# Checkpoints are serialized in compressed-tensors format, which is
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)

View File

@@ -0,0 +1,238 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Generator
from itertools import accumulate
from typing import Callable, Optional
import torch
from compressed_tensors.transform import (TransformArgs, TransformConfig,
TransformLocation, TransformScheme)
from compressed_tensors.utils import is_match
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
LinearMethodBase,
QKVCrossParallelLinear)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.module import ( # noqa: E501
HadamardTransform)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501
TransformTuple)
class CompressedTensorsLinearTransformMethod(LinearMethodBase):
"""
Wraps `CompressedTensorsLinearMethod` or `UnquantizedLinearMethod` and adds
input and output transforms to either side of the original apply method
"""
@classmethod
def from_schemes(
cls,
quant_method: LinearMethodBase,
quant_scheme: Optional[CompressedTensorsScheme],
input_tfms: dict[int, TransformTuple],
output_tfms: dict[int, TransformTuple],
) -> "CompressedTensorsLinearTransformMethod":
from vllm.model_executor.layers.quantization.compressed_tensors.transform.schemes.linear_qutlass_nvfp4 import ( # noqa: E501
QutlassNvFP4LinearMethod, is_qutlass_fp4_scheme)
assert input_tfms or output_tfms
if is_qutlass_fp4_scheme(quant_scheme, input_tfms):
return QutlassNvFP4LinearMethod(quant_method, input_tfms,
output_tfms)
# hadacore or dense gemm is selected by Transform module
return cls(quant_method, input_tfms, output_tfms)
def __init__(self, quant_method: LinearMethodBase,
input_tfms: dict[int, TransformTuple],
output_tfms: dict[int, TransformTuple]):
self.quant_method = quant_method
self.input_tfms = input_tfms
self.output_tfms = output_tfms
self.input_transform: Optional[HadamardTransform] = None
self.output_transform: Optional[HadamardTransform] = None
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
# get weight loader for transforms
weight_loader: Callable = extra_weight_attrs.get(
"weight_loader") # type: ignore[assignment]
# HACK: UnquantizedLinearMethod does not support weight loader v2, but
# transforms (specifically SharedWeightParameter) requires
# weight loader v2. Until UnquantizedLinearMethod supports v2, we must
# hack around this by getting weight loader v1 so ULM can load correctly
quant_method_name = self.quant_method.__class__.__name__
if quant_method_name not in WEIGHT_LOADER_V2_SUPPORTED:
if isinstance(layer, QKVCrossParallelLinear):
weight_loader_v1 = layer.weight_loader_v1
else:
weight_loader_v1 = layer.weight_loader
extra_weight_attrs["weight_loader"] = weight_loader_v1
self.quant_method.create_weights(
layer=layer,
input_size_per_partition=input_size_per_partition,
output_partition_sizes=output_partition_sizes,
input_size=input_size,
output_size=output_size,
params_dtype=params_dtype,
**extra_weight_attrs)
# validate schemes
num_partitions = len(output_partition_sizes)
self._validate_tfm_schemes(num_partitions)
# create submodules for weight loading
if len(self.input_tfms) > 0:
scheme_name = list(self.input_tfms.values())[0].scheme_name
location = list(self.input_tfms.values())[0].args.location
transform_name = f"{scheme_name}_{location}"
transform = HadamardTransform(self.input_tfms, layer,
weight_loader,
input_size_per_partition,
output_partition_sizes)
layer.register_module(transform_name, transform)
self.input_transform = transform
if len(self.output_tfms) > 0:
scheme_name = list(self.output_tfms.values())[0].scheme_name
location = list(self.output_tfms.values())[0].args.location
transform_name = f"{scheme_name}_{location}"
transform = HadamardTransform(self.output_tfms, layer,
weight_loader,
input_size_per_partition,
output_partition_sizes)
layer.register_module(transform_name, transform)
self.output_transform = transform
# compute partition ranges for slicing activations
starts = [0] + list(accumulate(output_partition_sizes))[:-1]
self.partition_ranges = list(zip(starts, output_partition_sizes))
def process_weights_after_loading(self, layer):
self.quant_method.process_weights_after_loading(layer)
for submodule in layer.children():
if isinstance(submodule, HadamardTransform):
submodule.process_weights_after_loading()
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.input_transform is not None:
x = self.input_transform(x)
assert bias is None
x = self.quant_method.apply(layer, x, bias)
# In most cases, input transforms are preferred over output transforms
# (@ksayers): confirm that this is done concurrently
if self.output_transform is not None:
for part_id, (start, length) in enumerate(self.partition_ranges):
x[:, start:start + length] = self.output_transform(
x[:, start:start + length].contiguous(), part_id=part_id)
return x
def _validate_tfm_schemes(self, num_partitions: int):
if len(self.input_tfms) > 0:
if 0 not in self.input_tfms:
raise ValueError("Must have same input")
for part_index in range(num_partitions):
if self.input_tfms[part_index] != self.input_tfms[0]:
raise ValueError("Must have same input")
if len(self.output_tfms) > 0:
scheme_name = list(self.output_tfms.values())[0].scheme_name
location = list(self.output_tfms.values())[0].args.location
for tfm in self.output_tfms.values():
if tfm.scheme_name != scheme_name:
raise ValueError("Must have same scheme name")
if tfm.args.location != location:
raise ValueError("Must have same location")
return self.input_tfms, self.output_tfms
def get_linear_transform_schemes(
layer: torch.nn.Module, layer_name: str,
transform_config: Optional[TransformConfig],
packed_modules_mapping: dict[str, list[str]]
) -> tuple[dict[int, TransformTuple], dict[
int, TransformTuple]]: # [input_transform, [output_transform, ...]]
# there can only be one transform input scheme per (fused) module
input_tfms = {}
output_tfms = {}
partition_names = get_layer_partition_names(layer_name,
packed_modules_mapping)
for scheme_name, scheme, args in get_schemes_args(transform_config):
for part_index, part_name in enumerate(partition_names):
if is_match(part_name, layer, args.targets,
args.ignore) and args.is_online():
if args.location == TransformLocation.INPUT:
input_tfms[part_index] = TransformTuple(
scheme_name, scheme, args)
elif args.location == TransformLocation.OUTPUT:
output_tfms[part_index] = TransformTuple(
scheme_name, scheme, args)
else:
raise ValueError(f"Cannot apply `{args.location}` "
f"transform to `{layer_name}`")
return (input_tfms, output_tfms)
def get_schemes_args(
transform_config: Optional[TransformConfig]
) -> Generator[tuple[str, TransformScheme, TransformArgs]]:
if transform_config is None:
return
for scheme_name, scheme in transform_config.config_groups.items():
for args in scheme.apply:
yield (scheme_name, scheme, args)
def get_layer_partition_names(
layer_name: str, packed_modules_mapping: dict[str,
list[str]]) -> list[str]:
"""
Get all partition names associated with this layer.
Names are returned in order of their partition indices.
```python
mapping = {"gate_up_proj", "gate_proj", "up_proj"}
assert get_layer_partition_names(
"mlp.gate_up_proj", mapping) == ["gate_proj", "up_proj"]
assert get_layer_partition_names(
"mlp.down_proj", mapping) == ["down_proj"]
"""
for fused_suffix, part_suffixes in packed_modules_mapping.items():
if layer_name.endswith(fused_suffix):
return [
layer_name.removesuffix(fused_suffix) + part_suffix
for part_suffix in part_suffixes
]
return [layer_name]

View File

@@ -0,0 +1,153 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Hashable
from typing import Callable
import torch
from compressed_tensors.transform import (TransformArgs, TransformLocation,
TransformScheme)
from torch import Tensor
import vllm._custom_ops as ops
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501
TransformTuple)
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parameter import SharedWeightParameter
class HadamardTransform(torch.nn.Module):
"""
Class which handles weight loading, postprocessing, and application of
transforms. Meant to be used with `CompressedTensorsLinearTransformMethod`
and attention transforms method (not implemented yet)
"""
transforms: dict[int, TransformTuple] # info parsed from transforms config
weight: SharedWeightParameter # container for shared tensors
scales: dict[int, float] # hadamard scale, usually sqrt(matrix.size(0))
def __init__(self, transforms: dict[int, TransformTuple],
layer: torch.nn.Module, weight_loader: Callable,
input_size_per_partition: int,
output_partition_sizes: list[int]):
super().__init__()
self.transforms = transforms
self.scales = {}
if get_tensor_model_parallel_world_size() > 1:
raise NotImplementedError("Online transforms with tensor "
"parallelism is not supported")
# Similar to row/col parallel params, but tensors are separate
# to allow for loading with shared memory
self.weight = SharedWeightParameter(weight_loader=weight_loader)
# create shared partition data for each partition of the original weight
input_size = input_size_per_partition
for part_index, (_scheme_name, scheme,
args) in self.transforms.items():
output_size = output_partition_sizes[part_index]
weight_size = self._get_weight_size(layer, scheme, args,
input_size, output_size)
data_key = self._get_data_key(scheme, weight_size)
self.weight.add_partition(
part_index,
data_key,
size=(weight_size, weight_size),
dtype=scheme.precision,
)
# validate that shared tensors and schemes are correct
self._validate_input_transforms()
def process_weights_after_loading(self):
for part_id in self.weight.partitions:
data = self.weight.partitions[part_id].data
# required by torch.compile
self.weight.process_weights_after_loading()
# precompute scale as a runtime multiply, not division
# do not fold into weight in order to utilize FWHT
self.scales[part_id] = 1 / math.sqrt(data.size(0))
# FUTURE: avoid runtime transpose by processing weights
# prior to apply
def forward(self, value: Tensor, part_id: int = 0) -> Tensor:
if part_id not in self.weight.partitions:
return value
# use hadacore if possible
if self.transforms[part_id].scheme.type == "hadamard":
if self.transforms[part_id].scheme.head_dim is not None:
weight_size = self.transforms[part_id].scheme.head_dim
value = value.unflatten(-1, (-1, weight_size))
value = ops.hadacore_transform(value)
value = value.flatten(-2, -1)
return value
# sylvester transforms are symmetric, inv => transpose => original
return ops.hadacore_transform(value)
# fall back to dense
else:
weight = self.weight.partitions[part_id]
weight = weight if self.transforms[
part_id].args.inverse else weight.T # linear := x(W.T)
scale = self.scales[part_id]
if self.transforms[part_id].scheme.head_dim is not None:
value = value.unflatten(-1, (-1, weight.size(0)))
value = dispatch_unquantized_gemm()(self, value.to(
weight.dtype), weight, None).to(value.dtype) * scale
value = value.flatten(-2, -1)
return value
return dispatch_unquantized_gemm()(self, value.to(
weight.dtype), weight, None).to(value.dtype) * scale
def _get_data_key(self, scheme: TransformScheme,
weight_size: int) -> Hashable:
return (id(scheme), weight_size)
def _get_weight_size(self, layer: torch.nn.Module, scheme: TransformScheme,
args: TransformArgs, input_size: int,
output_size: int) -> int:
if scheme.head_dim is not None:
return scheme.head_dim
if isinstance(layer, LinearBase):
if args.location == TransformLocation.INPUT:
return input_size
elif args.location == TransformLocation.OUTPUT:
return output_size
elif isinstance(layer, VocabParallelEmbedding):
if args.location == TransformLocation.INPUT:
return output_size
elif args.location == TransformLocation.OUTPUT:
return input_size
raise ValueError()
def _validate_input_transforms(self):
assert len(self.transforms) > 0
location = list(self.transforms.values())[0].args.location
if location == TransformLocation.INPUT:
first_data = self.weight.partitions[0].data
for partition in self.weight.partitions.values():
if partition.data.data_ptr() != first_data.data_ptr():
raise ValueError("")

View File

@@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsScheme, CompressedTensorsW4A4Fp4)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
CompressedTensorsLinearTransformMethod, TransformTuple)
__all__ = ["is_qutlass_fp4_scheme", "QutlassNvFP4LinearMethod"]
def is_qutlass_fp4_scheme(quant_scheme: Optional[CompressedTensorsScheme],
input_tfms: dict[int, TransformTuple]) -> bool:
return isinstance(
quant_scheme,
(CompressedTensorsW4A4Fp4, )) and len(input_tfms) == 1 and input_tfms[
0].scheme.head_dim == quant_scheme.group_size
class QutlassNvFP4LinearMethod(CompressedTensorsLinearTransformMethod):
def create_weights(self, layer, input_size_per_partition,
output_partition_sizes, input_size, output_size,
params_dtype, **extra_weight_attrs):
# initializes fp4 qparams
assert isinstance(layer.scheme, (CompressedTensorsW4A4Fp4, ))
ret = super().create_weights(layer, input_size_per_partition,
output_partition_sizes, input_size,
output_size, params_dtype,
**extra_weight_attrs)
assert self.input_transform is not None
assert len(self.input_transform.weight) == 1
assert self.input_transform.weight[0].size(
0) == layer.scheme.group_size
return ret
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
raise NotImplementedError()

View File

@@ -0,0 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import NamedTuple
from compressed_tensors.transform import TransformArgs, TransformScheme
__all__ = ["TransformTuple"]
class TransformTuple(NamedTuple):
scheme_name: str
scheme: TransformScheme
args: TransformArgs

View File

@@ -0,0 +1,206 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm.triton_utils import tl, triton
def is_weak_contiguous(x: torch.Tensor):
strides = x.stride()
sizes = x.shape
is_not_transpose = strides[0] == 1 and (strides[1] >= max(1, sizes[0]))
is_transpose = strides[1] == 1 and (strides[0] >= max(1, sizes[1]))
return is_transpose or is_not_transpose
@triton.jit
def scaled_mm_kernel(a_ptr, b_ptr, scale_a_ptr, scale_b_ptr, c_ptr, bias_ptr,
M, N, K, stride_am, stride_ak, stride_bk, stride_bn,
stride_cm, stride_cn, ACCUMULATOR_DTYPE: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
BLOCK_SIZE_SCALE_A: tl.constexpr,
BLOCK_SIZE_SCALE_B: tl.constexpr):
pid = tl.program_id(axis=0)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
pid_m = pid // num_pid_n
pid_n = pid % num_pid_n
accumulator_dtype = ACCUMULATOR_DTYPE
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),
dtype=accumulator_dtype)
# NOTE: Some tensor inputs are so large, they will cause int32 overflow
# so it is necessary to use tl.int64 for all the offsets, else SEGV will
# eventually occur.
# Offsets and masks.
offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
masks_am = offsets_am < M
offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
masks_bn = offsets_bn < N
offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
offsets_a = (stride_am * offsets_am[:, None] +
stride_ak * offsets_k[None, :])
offsets_b = (stride_bk * offsets_k[:, None] +
stride_bn * offsets_bn[None, :])
# NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create
# appropriate offsets and masks for each case. Same goes for
# BLOCK_SIZE_SCALE_B.
offsets_scale_am = (tl.arange(0, BLOCK_SIZE_SCALE_A) +
(BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M)
masks_scale_am = offsets_scale_am < M
offsets_scale_bn = (tl.arange(0, BLOCK_SIZE_SCALE_B) +
(BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N)
masks_scale_bn = offsets_scale_bn < N
a_ptrs = a_ptr + offsets_a
b_ptrs = b_ptr + offsets_b
scale_a_ptrs = scale_a_ptr + offsets_scale_am
scale_b_ptrs = scale_b_ptr + offsets_scale_bn
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
masks_k = offsets_k < K
masks_a = masks_am[:, None] & masks_k[None, :]
a = tl.load(a_ptrs, mask=masks_a)
masks_b = masks_k[:, None] & masks_bn[None, :]
b = tl.load(b_ptrs, mask=masks_b)
# Accumulate results.
accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
offsets_k += BLOCK_SIZE_K
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
# Apply scale at end.
masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a)
# Need to broadcast to the appropriate size, if scale_a is already
# (BLOCK_SIZE_M, 1) then it will broadcast to its own shape. Same goes
# for scale_b below.
scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1))
accumulator = scale_a * accumulator.to(tl.float32)
masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :]
scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b)
scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1))
accumulator = scale_b.T * accumulator.to(tl.float32)
# Convert to output format.
c = accumulator.to(c_ptr.type.element_ty)
# Add bias, it's already in output format, so add it after conversion.
if bias_ptr:
offsets_bias = offsets_bn
bias_ptrs = bias_ptr + offsets_bias
bias_mask = offsets_bias < N
bias = tl.load(bias_ptrs, bias_mask)
c += bias
# Save output
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
offs_cm = offs_cm.to(tl.int64)
offs_cn = offs_cn.to(tl.int64)
c_ptrs = (c_ptr + stride_cm * offs_cm[:, None] +
stride_cn * offs_cn[None, :])
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
# input - [M, K]
# weight - [K, N]
def triton_scaled_mm(input: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: type[torch.dtype],
bias: Optional[torch.Tensor] = None,
block_size_m: int = 32,
block_size_n: int = 32,
block_size_k: int = 32,
use_heuristic=True) -> torch.Tensor:
M, K = input.shape
N = weight.shape[1]
assert N > 0 and K > 0 and M > 0
assert weight.shape[0] == K
assert input.dtype == weight.dtype
scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a
scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b
assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1
or scale_a.shape[0] == M)
assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1
or scale_b.shape[0] == N)
assert out_dtype.is_floating_point
assert bias is None or bias.is_floating_point()
assert is_weak_contiguous(input)
assert is_weak_contiguous(weight)
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
N, META['BLOCK_SIZE_N']), )
result = torch.empty((M, N), dtype=out_dtype, device=input.device)
has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1
if use_heuristic:
is_small_N = N < 8192
next_power_of_2_M = max(32, triton.next_power_of_2(M))
if next_power_of_2_M <= 32:
tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256)
elif next_power_of_2_M <= 64:
tile_shape = (64, 64, 256)
elif next_power_of_2_M <= 128:
tile_shape = (64, 128, 128)
else:
tile_shape = (128, 128, 128)
block_size_m, block_size_n, block_size_k = tile_shape
block_size_sa = 1 if has_scalar(scale_a) else block_size_m
block_size_sb = 1 if has_scalar(scale_b) else block_size_n
accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32
# A = input, B = weight, C = result
# A = M x K, B = K x N, C = M x N
scaled_mm_kernel[grid](input,
weight,
scale_a,
scale_b,
result,
bias,
M,
N,
K,
input.stride(0),
input.stride(1),
weight.stride(0),
weight.stride(1),
result.stride(0),
result.stride(1),
accumulator_dtype,
BLOCK_SIZE_M=block_size_m,
BLOCK_SIZE_N=block_size_n,
BLOCK_SIZE_K=block_size_k,
BLOCK_SIZE_SCALE_A=block_size_sa,
BLOCK_SIZE_SCALE_B=block_size_sb)
return result.to(out_dtype)

View File

@@ -0,0 +1,216 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping
from types import MappingProxyType
from typing import Optional
import regex as re
from compressed_tensors import CompressionFormat
from torch.nn import Module
def is_activation_quantization_format(format: str) -> bool:
_ACTIVATION_QUANTIZATION_FORMATS = [
CompressionFormat.naive_quantized.value,
CompressionFormat.int_quantized.value,
CompressionFormat.float_quantized.value,
CompressionFormat.nvfp4_pack_quantized.value
]
return format in _ACTIVATION_QUANTIZATION_FORMATS
def should_ignore_layer(
layer_name: Optional[str],
ignore: Iterable[str] = tuple(),
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
) -> bool:
if layer_name is None:
return False
# layer_name = model.layers.0.self_attn.qkv_proj
# proj_name = qkv_proj
proj_name = layer_name.split(".")[-1]
# Fused layers like gate_up_proj or qkv_proj will not be fused
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if proj_name in fused_mapping and layer_name not in ignore:
shard_proj_names = fused_mapping[proj_name]
# Convert fused_name --> [shard_names]
shard_names = [
layer_name.replace(proj_name, shard_proj_name)
for shard_proj_name in shard_proj_names
]
# Layer should be ignored if shards are ignored.
should_ignore_layer = None
for shard_name in shard_names:
should_ignore_shard = check_equal_or_regex_match(
layer_name=shard_name, targets=ignore)
# If shard_idx=0, set layer ignore to match shard.
if should_ignore_layer is None:
should_ignore_layer = should_ignore_shard
# If shard_idx=1+ confirm scheme matches prior shards.
elif should_ignore_shard != should_ignore_layer:
raise ValueError(f"Found a different quantization schemes for "
f"{shard_proj_names} in {layer_name}. vLLM "
"requires all to use the same scheme.")
# Unfused layers like down_proj and o_proj will match
# the safetensors checkpoint already.
else:
should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name,
targets=ignore)
assert should_ignore_layer is not None
return should_ignore_layer
def check_equal_or_regex_match(layer_name: str,
targets: Iterable[str]) -> bool:
"""
Checks whether a layer_name is exactly equal or a regex match for
if target starts with 're:' to any target in list.
"""
for target in targets:
if _is_equal_or_regex_match(layer_name, target):
return True
return False
def find_matched_target(
layer_name: Optional[str],
module: Module,
targets: Iterable[str],
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
) -> str:
"""
Helper function to look up which "target" in the compressed-tensors
config that a layer corresponds to.
Recall that a compressed-tensors configs has a concept of
config_groups, where each layer can be quantized with a different
scheme.
targets in each config_group will be a list of either layer names
(or regexes corresponding to layer names) or names of torch Modules.
First, we try to match the layer_name with a target
Second, we try to match the module's name with a target
Third, we try to map the layer_name to a list of fused module names.
*All* component module names must match in order for a match to be
successful. A successful match returns the first component target
:param layer_name: layer name
:param module: torch.nn.Module
:param targets: list of targets to match the layer against
:param fused_mapping: map from fused layer names to its components
:param fused_strategy: either "all" or "any". If using "all", fused
layers match if "all" of its components match
"""
if layer_name is None:
layer_name = ""
matched_target = (
_find_first_match(layer_name, targets)
or _find_first_match(module.__class__.__name__, targets, True)
or _match_fused_layer(layer_name, targets, fused_mapping))
if matched_target is None:
raise ValueError(
f"Unable to find matching target for {layer_name} in the "
"compressed-tensors config.")
return matched_target
def _find_first_match(value: str,
targets: Iterable[str],
check_contains: bool = False) -> Optional[str]:
"""
Returns first element of target that matches value either
exactly or as a regex after 're:'. If check_contains is set to True,
additionally checks if the target string is contained within the value.
:param value: string to compare the list of targets against
:param targets: list of targets to match the layer against
:param check_contains: whether or not to do a substring match
"""
for target in targets:
if _is_equal_or_regex_match(value,
target,
check_contains=check_contains):
return target
return None
def _is_equal_or_regex_match(value: str,
target: str,
check_contains: bool = False) -> bool:
"""
Checks whether a value is exactly equal or a regex match for target
if target starts with 're:'. If check_contains is set to True,
additionally checks if the target string is contained within the value.
"""
if target.startswith("re:"):
pattern = target[3:]
if re.match(pattern, value):
return True
elif check_contains:
if target.lower() in value.lower():
return True
elif target == value:
return True
return False
def _match_fused_layer(
layer_name: str, target_layers: Iterable[str],
fused_mapping: Mapping[str, list[str]]) -> Optional[str]:
"""
Match a fused layer name to its corresponding individual layer in
target_layers. Returns first value in fused_mapping which matches targets
Implements an "all" matching strategy where a fused layer matches iff
"all" of its components match
:param layer_name: layer name
:param target_layers: list of targets to match the layer against
:param fused_mapping: map from fused layer names to its components
Examples:
layer_name = "model.layers.0.self_attn.qkv_proj"
target_layers = ["model.layers.0.self_attn.q_proj",
"model.layers.0.self_attn.k_proj",
"model.layers.0.self_attn.v_proj"]
"""
# find layer_name in mapping
fused = next((key for key in fused_mapping if layer_name.endswith(key)),
None)
if fused is None:
return None
# expand path of unfused components
unfused_paths = [
layer_name.replace(fused, unfused) for unfused in fused_mapping[fused]
]
# for each unfused component, find a match in targets
unfused_matches: list[Optional[str]] = []
for unfused in unfused_paths:
for target in target_layers:
if _is_equal_or_regex_match(unfused, target):
unfused_matches.append(target)
break
else:
unfused_matches.append(None)
return unfused_matches[0] if all(unfused_matches) else None

View File

@@ -0,0 +1,196 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging import version
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
class DeepSpeedFPConfig(QuantizationConfig):
"""Config for DeepSpeed FP quantizer. It supports fp6 and fp8.
Args:
weight_bits: the target quantization bits, 6 or 8.
group_size: group size for quantizaiton, default to 128.
"""
def __init__(
self,
weight_bits: int = 8,
group_size: int = 512,
) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.valid_types = [torch.bfloat16, torch.float16]
if self.weight_bits not in (6, 8):
raise ValueError(
"Currently, only 6-bit or 8-bit weight quantization are "
f"supported for DeepSpeed FP quantizaiton, but got "
f"{self.weight_bits} bits.")
def __repr__(self) -> str:
return (f"DeepSpeedFPConfig(weight_bits={self.weight_bits}), "
f"group_size={self.group_size}")
@classmethod
def get_name(cls) -> QuantizationMethods:
return "deepspeedfp"
@classmethod
def from_config(cls, config: dict[str, Any]) -> "DeepSpeedFPConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
return cls(weight_bits=weight_bits, group_size=group_size)
def get_linear_method(self) -> "DeepSpeedFPLinearMethod":
return DeepSpeedFPLinearMethod(self)
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
# Need to figure it out
def get_min_capability(cls) -> int:
return 60
@staticmethod
def get_config_filenames() -> list[str]:
return [
"quant_config.json",
"quantize_config.json",
]
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["DeepSpeedFPLinearMethod"]:
if isinstance(layer, LinearBase):
return DeepSpeedFPLinearMethod(self)
return None
class DeepSpeedFPLinearMethod(LinearMethodBase):
"""Linear method for DeepSpeedFP quantizer.
Args:
quant_config: the DeepSpeedFP quantization config.
"""
def __init__(self, quant_config: DeepSpeedFPConfig):
self.quant_config = quant_config
self.weight = None
def create_weights(self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
weight_loader=None,
**extra_weight_attrs):
del output_size
del input_size
output_size_per_partition = sum(output_partition_sizes)
weight = DeepSpeedFPParameter(
torch.Size((output_size_per_partition, input_size_per_partition)),
params_dtype=params_dtype,
quant_config=self.quant_config,
)
set_weight_attrs(weight, {
"input_dim": 1,
"output_dim": 0,
})
layer.register_parameter("weight", weight)
def quant_weight_loader(param, loaded_weight, *args, **kwargs):
# Calls the original weight loader (if any), quantizes the result,
# and then loads the quantized parameter.
if weight_loader is not None:
orig_param_data = param.data
param.data = param.ds_dequantize()
weight_loader(param, loaded_weight, *args, **kwargs)
param.data, loaded_weight = orig_param_data, param.data
param.ds_quantize_(loaded_weight.cuda())
extra_weight_attrs["weight_loader"] = quant_weight_loader
set_weight_attrs(weight, extra_weight_attrs)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = layer.weight
y = weight.ds_dequantize()
return F.linear(x, y, bias)
class DeepSpeedFPParameter(nn.Parameter):
"""
DeepSpeedFP quantized parameter class that implements fp8/fp6
quantization deepspeed. Weights are stored in quantized form on
GPUs, and can be dequantized on-the-fly when needed by the model.
"""
def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype,
quant_config: DeepSpeedFPConfig):
try:
import deepspeed
if version.parse(deepspeed.__version__) < version.parse("0.14.2"):
raise ImportError("deepspeed version is wrong. Please "
"install deepspeed>=0.14.2.")
from deepspeed.ops.fp_quantizer import FP_Quantize
except ImportError as err:
raise ImportError("Please install deepspeed>=0.14.2 via "
"`pip install deepspeed>=0.14.2` to use "
"deepspeedfp quantizer.") from err
data = torch.empty((
orig_shape.numel() // quant_config.group_size,
quant_config.group_size * quant_config.weight_bits // 8 + 4,
),
dtype=torch.int8)
self = torch.Tensor._make_subclass(cls, data, data.requires_grad)
self.orig_shape = orig_shape
self.quant_config = quant_config
self.fp_quantizer = FP_Quantize(group_size=quant_config.group_size)
self.fp_quantizer.orig_shape = orig_shape
self.fp_quantizer.orig_dtype = params_dtype
return self
def ds_quantize_(self, tensor: torch.Tensor):
assert tensor.device.type == "cuda" and tensor.dtype != torch.int8
return self.data.copy_(
self.fp_quantizer.quantize(
tensor.data,
q_bits=self.quant_config.weight_bits,
))
def ds_dequantize(self, fp_out=None) -> torch.Tensor:
"""
Return a tensor containing the dequantized weights of this parameter.
"""
assert self.data.device.type == "cuda" and self.data.dtype == torch.int8
return self.fp_quantizer.dequantize(
self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits)
def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor:
"""
Return a tensor where only the weights at `indices` are dequantized
(to save HBM -> SRAM bandwidth).
"""
assert self.data.device.type == "cuda" and self.data.dtype == torch.int8
return self.fp_quantizer.selective_dequantize(
self.data,
indices,
fp_out=fp_out,
q_bits=self.quant_config.weight_bits)

View File

@@ -0,0 +1,223 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional, Union
import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, int8_w8a16_moe_quant_config)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
class ExpertsInt8Config(QuantizationConfig):
"""Config class for Int8 experts quantization."""
def __init__(self) -> None:
super().__init__()
@classmethod
def get_name(cls) -> QuantizationMethods:
return "experts_int8"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> list[str]:
return []
@classmethod
def from_config(cls, config: dict[str, Any]) -> "ExpertsInt8Config":
return cls()
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE):
return ExpertsInt8MoEMethod(self, layer.moe_config)
return None
class ExpertsInt8MoEMethod(FusedMoEMethodBase):
def __init__(
self,
quant_config: ExpertsInt8Config,
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
int8_dtype = torch.int8
assert 'weight_loader' in extra_weight_attrs
weight_loader = extra_weight_attrs['weight_loader']
wrapped_weight_loader = ExpertsInt8MoEMethod.quantizing_weight_loader(
layer, weight_loader)
extra_weight_attrs['weight_loader'] = wrapped_weight_loader
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=int8_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=int8_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w13_scale = torch.nn.Parameter(torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_scale", w13_scale)
w2_scale = torch.nn.Parameter(torch.zeros(num_experts,
hidden_size,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_scale", w2_scale)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return int8_w8a16_moe_quant_config(w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale,
w1_zp=None,
w2_zp=None)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ExpertsInt8MoEMethod` yet.")
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
quant_config=self.moe_quant_config,
)
@staticmethod
def quantizing_weight_loader(layer, weight_loader):
def quantize_and_call_weight_loader(param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str, shard_id: int,
expert_id: int):
tp_rank = get_tensor_model_parallel_rank()
shard_size = layer.intermediate_size_per_partition
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
device = get_tp_group().device
loaded_weight = loaded_weight.to(device)
# w1, gate_proj case: Load into first shard of w13.
if shard_id == "w1":
scales = quantize_in_place_and_get_scales(
loaded_weight[shard, :])
layer.w13_scale.data[expert_id, 0:shard_size].copy_(scales[:,
0])
# w3, up_proj case: Load into second shard of w13.
elif shard_id == "w3":
scales = quantize_in_place_and_get_scales(
loaded_weight[shard, :])
layer.w13_scale.data[expert_id, shard_size:2 *
shard_size].copy_(scales[:, 0])
# w2, down_proj case: Load into only shard of w2.
elif shard_id == "w2":
scales = quantize_in_place_and_get_scales(loaded_weight[:,
shard])
layer.w2_scale.data[expert_id, :].copy_(scales[:, 0])
else:
raise ValueError(
f"Shard id must be in [0,1,2] but got {shard_id}")
weight_loader(param, loaded_weight, weight_name, shard_id,
expert_id)
return quantize_and_call_weight_loader
def quantize_in_place_and_get_scales(weight: torch.Tensor) -> torch.Tensor:
vmax = torch.iinfo(torch.int8).max
scales = (torch.max(torch.abs(weight), dim=1, keepdim=True)[0] / vmax)
weight.div_(scales)
weight.round_()
weight.clamp_(-vmax, vmax)
return scales

View File

@@ -0,0 +1,172 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter)
from vllm.platforms import current_platform
logger = init_logger(__name__)
class FBGEMMFp8Config(QuantizationConfig):
"""Config class for FBGEMM Fp8."""
def __init__(self, ignore_list: list[str], input_scale_ub: float):
super().__init__()
self.ignore_list = ignore_list if ignore_list else []
self.input_scale_ub = input_scale_ub
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = not current_platform.has_device_capability(89)
@classmethod
def get_name(cls) -> QuantizationMethods:
return "fbgemm_fp8"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.float16]
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> list[str]:
return []
@classmethod
def from_config(cls, config: dict[str, Any]) -> "FBGEMMFp8Config":
ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix=prefix,
ignored_layers=self.ignore_list,
fused_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod()
return FBGEMMFp8LinearMethod(self)
return None
class FBGEMMFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: FBGEMMFp8Config):
self.quant_config = quant_config
self.fp8_linear = Fp8LinearOp(
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN)
self.out_dtype = torch.get_default_dtype()
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
maybe_create_device_identity()
weight_loader = extra_weight_attrs.get("weight_loader")
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
# WEIGHT
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
weight_scale = ChannelQuantScaleParameter(data=torch.empty(
(sum(output_partition_sizes), 1), dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE UPPER BOUND
input_scale_ub = torch.nn.Parameter(torch.tensor(
(self.quant_config.input_scale_ub), dtype=torch.float32),
requires_grad=False)
layer.input_scale_ub = input_scale_ub
def process_weights_after_loading(self, layer: Module) -> None:
# required by torch.compile
layer.weight_scale = Parameter(layer.weight_scale.data,
requires_grad=False)
layer.weight = Parameter(layer.weight.data, requires_grad=False)
weight = layer.weight
if current_platform.is_fp8_fnuz():
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=layer.weight_scale,
input_scale=None)
if input_scale is not None:
layer.input_scale = Parameter(input_scale, requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.weight = Parameter(weight.t(), requires_grad=False)
if self.quant_config.use_marlin:
prepare_fp8_layer_for_marlin(layer)
# Activations not quantized for marlin.
del layer.input_scale_ub
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.quant_config.use_marlin:
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias)
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=None,
input_scale_ub=layer.input_scale_ub,
bias=bias)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,599 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional, Union
import gguf
import torch
from gguf import GGMLQuantizationType as WeightType
from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import direct_register_custom_op
logger = init_logger(__name__)
class GGUFConfig(QuantizationConfig):
"""Config class for GGUF."""
def __init__(self,
unquantized_modules: Optional[list[str]] = None) -> None:
super().__init__()
self.unquantized_modules = unquantized_modules or []
def __repr__(self) -> str:
return ("GGUFConfig()")
def get_name(self) -> QuantizationMethods:
return "gguf"
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.half, torch.bfloat16, torch.float32]
@classmethod
def get_min_capability(cls) -> int:
return 60
@classmethod
def get_config_filenames(cls) -> list[str]:
return [] # no extra configs.
@classmethod
def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
return cls()
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
if is_layer_skipped_gguf(prefix, self.unquantized_modules):
return UnquantizedLinearMethod()
return GGUFLinearMethod(self)
elif isinstance(layer, VocabParallelEmbedding):
return GGUFEmbeddingMethod(self)
elif isinstance(layer, FusedMoE):
return GGUFMoEMethod(self, layer.moe_config)
return None
def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]):
return any(module_name in prefix for module_name in unquantized_modules)
UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
STANDARD_QUANT_TYPES = {
WeightType.Q4_0,
WeightType.Q4_1,
WeightType.Q5_0,
WeightType.Q5_1,
WeightType.Q8_0,
WeightType.Q8_1,
}
KQUANT_TYPES = {
WeightType.Q2_K,
WeightType.Q3_K,
WeightType.Q4_K,
WeightType.Q5_K,
WeightType.Q6_K,
}
IMATRIX_QUANT_TYPES = {
WeightType.IQ1_M,
WeightType.IQ1_S,
WeightType.IQ2_XXS,
WeightType.IQ2_XS,
WeightType.IQ2_S,
WeightType.IQ3_XXS,
WeightType.IQ3_S,
WeightType.IQ4_XS,
WeightType.IQ4_NL,
}
# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
# MMQ kernel for I-Matrix quantization.
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
qweight_type: int) -> torch.Tensor:
if qweight_type in IMATRIX_QUANT_TYPES:
mmvq_safe = 8 if qweight.shape[0] > 5120 else 16
else:
mmvq_safe = 2 if qweight.shape[0] > 5120 else 6
# HACK: when doing chunked prefill we don't generate output tokens
# so input to logits generator is empty which causes invalid parameter
if x.shape[0] == 0:
return torch.empty(x.shape[0],
qweight.shape[0],
dtype=x.dtype,
device=x.device)
# there is no need to call any kernel for fp16/bf16
if qweight_type in UNQUANTIZED_TYPES:
return x @ qweight.T
# enable MMVQ in contiguous batching with batch_size=1
if x.shape[0] <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES:
y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
# Use MMQ Kernel if it's available (standard + k-quants)
elif qweight_type in MMQ_QUANT_TYPES:
y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
# If there is no available MMQ kernel, fallback to dequantize
elif qweight_type in DEQUANT_TYPES:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
y = x @ weight.T
else:
# Raise an error if the quantization type is not supported.
# Might be useful if llama.cpp adds a new quantization type.
# Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
qweight_type = WeightType(qweight_type)
raise NotImplementedError(
f"Unsupported GGUF quantization type: {qweight_type}")
return y
def _fused_mul_mat_gguf_fake(
x: torch.Tensor,
qweight: torch.Tensor,
qweight_type: int,
) -> torch.Tensor:
return torch.empty(x.shape[0],
qweight.shape[0],
dtype=x.dtype,
device=x.device)
try:
direct_register_custom_op(
op_name="_fused_mul_mat_gguf",
op_func=_fused_mul_mat_gguf,
fake_impl=_fused_mul_mat_gguf_fake,
)
fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
except AttributeError as error:
raise error
def _fused_moe_gguf(
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
qweight_type: int,
qweight_type2: int,
activation: str,
) -> torch.Tensor:
def act(x: torch.Tensor):
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if activation == "silu":
torch.ops._C.silu_and_mul(out, x)
elif activation == "gelu":
torch.ops._C.gelu_and_mul(out, x)
else:
raise ValueError(f"Unsupported activation: {activation}")
return out
# lazy import to avoid triggering triton import in CPU backend
from vllm.model_executor.layers.fused_moe.fused_moe import (
moe_align_block_size)
out_hidden_states = torch.empty_like(x)
# unless we decent expert reuse we are better off running moe_vec kernel
if (qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES
and x.shape[0] > 64):
num_tokens, _ = x.shape
E, N, _ = w1.shape
top_k = topk_ids.shape[1]
BLOCK_SIZE = ops.ggml_moe_get_block_size(qweight_type)
sorted_token_ids, expert_ids, num_tokens_post_padded = \
moe_align_block_size(topk_ids, BLOCK_SIZE, E)
out = ops.ggml_moe_a8(x, w1, sorted_token_ids, expert_ids,
num_tokens_post_padded, qweight_type, N, top_k,
num_tokens)
out = act(out)
out = ops.ggml_moe_a8(out, w2, sorted_token_ids, expert_ids,
num_tokens_post_padded, qweight_type2,
w2.shape[1], 1, num_tokens * top_k)
out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
topk_weights.view(num_tokens, top_k, 1))
ops.moe_sum(out, out_hidden_states)
elif qweight_type2 in MMVQ_QUANT_TYPES and qweight_type in MMVQ_QUANT_TYPES:
num_tokens, _ = x.shape
E, N, _ = w1.shape
top_k = topk_ids.shape[1]
out = ops.ggml_moe_a8_vec(x, w1, topk_ids, top_k, qweight_type, N,
num_tokens)
out = act(out)
out = ops.ggml_moe_a8_vec(out, w2, topk_ids, 1, qweight_type2,
w2.shape[1], num_tokens * top_k)
out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
topk_weights.view(num_tokens, top_k, 1))
ops.moe_sum(out, out_hidden_states)
else:
logger.warning_once("There is no support for fast MoE kernel "
"for current quantization method. "
"Falling back to slow implementation. ")
for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
inp = x[tok].reshape((1, ) + x.shape[1:])
current_hidden_state = None
for ww, ii in zip(w, idx):
expert_up = w1[ii]
out = fused_mul_mat_gguf(inp, expert_up, qweight_type)
out = act(out)
expert_down = w2[ii]
current_state = fused_mul_mat_gguf(out, expert_down,
qweight_type2).mul_(ww)
if current_hidden_state is None:
current_hidden_state = current_state
else:
current_hidden_state.add_(current_state)
out_hidden_states[tok] = current_hidden_state
return out_hidden_states
def _fused_moe_gguf_fake(
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
qweight_type: int,
qweight_type2: int,
activation: str,
) -> torch.Tensor:
return torch.empty_like(x)
try:
direct_register_custom_op(
op_name="_fused_moe_gguf",
op_func=_fused_moe_gguf,
fake_impl=_fused_moe_gguf_fake,
)
fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
except AttributeError as error:
raise error
def _apply_gguf_embedding(
x: torch.Tensor,
qweight: torch.Tensor,
qweight_type: int,
hidden_size: int,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
if qweight_type in UNQUANTIZED_TYPES:
return torch.embedding(qweight, x)
elif qweight_type in DEQUANT_TYPES:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
x_flat = x.flatten()
assert (hidden_size == qweight.shape[1] // type_size * block_size)
quant = torch.index_select(qweight, dim=0, index=x_flat)
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
x_flat.shape[0], dtype)
return dequant.view(*x.shape, hidden_size)
else:
qweight_type = WeightType(qweight_type)
raise NotImplementedError(
f"Unsupported GGUF quantization type: {qweight_type}")
def _apply_gguf_embedding_fake(
x: torch.Tensor,
qweight: torch.Tensor,
qweight_type: int,
hidden_size: int,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device)
try:
direct_register_custom_op(
op_name="_apply_gguf_embedding",
op_func=_apply_gguf_embedding,
fake_impl=_apply_gguf_embedding_fake,
)
apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
except AttributeError as error:
raise error
class GGUFLinearMethod(LinearMethodBase):
"""Linear method for GGUF.
Args:
quant_config: The GGUF quantization config.
"""
def __init__(self, quant_config: GGUFConfig):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
self.params_dtype = params_dtype
output_size_per_partition = sum(output_partition_sizes)
tensor_shape = (output_size_per_partition, input_size_per_partition)
qweight = GGUFUninitializedParameter(requires_grad=False)
set_weight_attrs(
qweight, {
"input_dim": 1,
"output_dim": 0,
"tensor_shape": tensor_shape,
"is_gguf_weight": True,
"data_container": [],
"shard_id": [],
"shard_id_map": {},
})
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("qweight", qweight)
qweight_type = Parameter(torch.empty(len(output_partition_sizes),
dtype=torch.uint8),
requires_grad=False)
set_weight_attrs(
qweight_type, {
"is_gguf_weight_type": True,
"weight_type": 0,
"shard_weight_type": {},
"ignore_warning": True
})
set_weight_attrs(qweight_type, extra_weight_attrs)
layer.register_parameter("qweight_type", qweight_type)
def process_weights_after_loading(self, layer: torch.nn.Module):
qweight_type = layer.qweight_type.weight_type
if not (qweight_type in UNQUANTIZED_TYPES
or qweight_type in DEQUANT_TYPES):
qweight_type = WeightType(qweight_type)
raise ValueError(
f"Unsupported GGUF quantization type {qweight_type} in "
f"layer {layer}.")
# For MergedColumnParallelLinear and QKVParallelLinear, we need to
# materialize the padded weight parameter for CUDA Graph compatibility.
self._create_padded_weight_param(layer)
def _create_padded_weight_param(self, layer: torch.nn.Module):
"""Create padded weight parameter for GGUF MergedLinear layer."""
qweight = layer.qweight
shard_id_map = qweight.shard_id_map
shard_id = qweight.shard_id
if len(data_container := qweight.data_container) > 1:
dtype = {data.dtype for data in data_container}
assert len(dtype) == 1, ValueError(
f"Data container has mixed dtypes: {dtype}")
dtype = next(iter(dtype))
# concat dim0 and pad dim1
padded_side = max(x.size(1) for x in data_container)
concat_side = sum(x.size(0) for x in data_container)
# Pad the quantized weights to dense tensor, and create a map
# with the location of each shard in the padded tensor.
padded_data = torch.zeros((concat_side, padded_side),
dtype=dtype,
device=qweight.device)
# (dim0_start, dim0_end, dim1_size)
shard_offset_map = dict[str, tuple[int, int, int]]()
for idx in shard_id:
id_in_container = shard_id_map[idx]
start = sum(
x.size(0) for x in data_container[:id_in_container])
end = start + data_container[id_in_container].size(0)
size = data_container[id_in_container].size(1)
padded_data[start:end, :size] = data_container[id_in_container]
shard_offset_map[idx] = (start, end, size)
qweight.data_container.clear()
padded_param = Parameter(padded_data, requires_grad=False)
set_weight_attrs(padded_param, vars(qweight))
set_weight_attrs(padded_param,
{"shard_offset_map": shard_offset_map})
layer.register_parameter("qweight", padded_param)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
shard_id = layer.qweight.shard_id
if shard_id:
# dequantize shard weights respectively
shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
qweight = layer.qweight
result = []
for idx in shard_id:
start, end, offset = layer.qweight.shard_offset_map[idx]
qweight_type = layer.qweight_type.shard_weight_type[idx]
result.append(
fused_mul_mat_gguf(
x, qweight[start:end, :offset].contiguous(),
qweight_type))
out = torch.cat(result, axis=1)
else:
qweight = layer.qweight
qweight_type = layer.qweight_type.weight_type
out = fused_mul_mat_gguf(x, qweight, qweight_type)
if bias is not None:
out.add_(bias)
return out
class GGUFMoEMethod(FusedMoEMethodBase):
"""MoE method for GGUF.
Args:
quant_config: The GGUF quantization config.
"""
def __init__(
self,
quant_config: GGUFConfig,
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
tensor_shape = (num_experts, 2 * intermediate_size_per_partition,
hidden_size)
#gate up proj
w13_qweight = GGUFUninitializedParameter(requires_grad=False)
set_weight_attrs(
w13_qweight, {
"input_dim": 1,
"output_dim": 0,
"tensor_shape": tensor_shape,
"is_gguf_weight": True,
"data_container": [],
})
set_weight_attrs(w13_qweight, extra_weight_attrs)
layer.register_parameter("w13_qweight", w13_qweight)
w13_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
requires_grad=False)
set_weight_attrs(w13_qweight_type, {
"is_gguf_weight_type": True,
"weight_type": 0,
"ignore_warning": True
})
set_weight_attrs(w13_qweight_type, extra_weight_attrs)
layer.register_parameter("w13_qweight_type", w13_qweight_type)
tensor_shape = (num_experts, intermediate_size_per_partition,
hidden_size)
#gate down proj
w2_qweight = GGUFUninitializedParameter(requires_grad=False)
set_weight_attrs(
w2_qweight, {
"input_dim": 1,
"output_dim": 0,
"tensor_shape": tensor_shape,
"is_gguf_weight": True,
"data_container": [],
})
set_weight_attrs(w2_qweight, extra_weight_attrs)
layer.register_parameter("w2_qweight", w2_qweight)
w2_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
requires_grad=False)
set_weight_attrs(w2_qweight_type, {
"is_gguf_weight_type": True,
"weight_type": 0,
"ignore_warning": True
})
set_weight_attrs(w2_qweight_type, extra_weight_attrs)
layer.register_parameter("w2_qweight_type", w2_qweight_type)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `GGUFMoEMethod` yet.")
assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
"fused GGUF MoE method.")
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
topk_weights, topk_ids,
layer.w13_qweight_type.weight_type,
layer.w2_qweight_type.weight_type, activation)
class GGUFEmbeddingMethod(GGUFLinearMethod):
"""Embedding method for GGUF.
Args:
quant_config: The GGUF quantization config.
"""
def embedding(self, layer: torch.nn.Module,
x: torch.Tensor) -> torch.Tensor:
qweight = layer.qweight
qweight_type = layer.qweight_type.weight_type
hidden_size = qweight.tensor_shape[1]
return apply_gguf_embedding(x,
qweight,
qweight_type,
hidden_size,
dtype=self.params_dtype)
class GGUFUninitializedParameter(UninitializedParameter):
cls_to_become = Parameter
data_container: list[torch.Tensor]

View File

@@ -0,0 +1,340 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from enum import Enum
from fractions import Fraction
from typing import Any, Optional, Union
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
from vllm.transformers_utils.config import get_safetensors_params_metadata
from vllm.utils import is_list_of
class GPTQConfig(QuantizationConfig):
"""Config class for GPTQ.
Reference: https://arxiv.org/abs/2210.17323
"""
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
lm_head_quantized: bool,
dynamic: dict[str, dict[str, Union[int, bool]]],
autoround_version: str = "",
modules_in_block_to_quantize: Optional[list[str]] = None,
) -> None:
# GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized.
# Format is dict[str, dict] where key is a regex string that can
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
# matching of a module.
# Default to positive match, override base quant config mode, if no
# prefix is used. Value is in dict format of field key and override
# value.
# Negative matching will skip quantization init for this module
# entirely:
# non-quantized inference. More details and quantization examples can be
# found at: https://github.com/ModelCloud/GPTQModel
# Example:
# # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
# # last 1/4 of the layers 16-21 has 8bit and group_size 64
# dynamic = {
# #`.*\.` matches the layers_node prefix
# # positive match layer 10-15
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
# # positive match layer 16-21
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
# }
super().__init__()
self.dynamic = dynamic
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.lm_head_quantized = lm_head_quantized
self.pack_factor = Fraction(32, self.weight_bits)
if self.weight_bits not in [2, 3, 4, 8]:
raise ValueError(
"Currently, only 2/3/4/8-bit weight quantization is "
f"supported for GPTQ, but got {self.weight_bits} bits.")
self.modules_in_block_to_quantize = modules_in_block_to_quantize or []
# used to identify GPTQ model quantized by autoround
self.autoround_version = autoround_version
def __repr__(self) -> str:
return (
f"GPTQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}), "
f"lm_head_quantized={self.lm_head_quantized}, "
f"dynamic={self.dynamic}, "
f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})"
)
@classmethod
def get_name(cls) -> QuantizationMethods:
return "gptq"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half]
@classmethod
# Need to figure it out
def get_min_capability(cls) -> int:
return 60
@classmethod
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: dict[str, Any]) -> "GPTQConfig":
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
dynamic = {} if dynamic is None else dynamic
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
autoround_version = cls.get_from_keys_or(config, ["autoround_version"],
default="")
modules_in_block_to_quantize = cls.get_from_keys_or(
config, ["modules_in_block_to_quantize"], default=None)
return cls(weight_bits, group_size, desc_act, lm_head_quantized,
dynamic, autoround_version, modules_in_block_to_quantize)
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[Union["GPTQLinearMethod", "QuantizeMethodBase"]]:
if isinstance(layer, FusedMoE):
# GPTQ MoE support: fall back to MoeWNA16 for broad compatibility
from .moe_wna16 import MoeWNA16Config
config = {
"quant_method": "gptq",
"bits": self.weight_bits,
"group_size": self.group_size,
"sym": True, # GPTQ typically uses symmetric quantization
"lm_head": False,
}
return MoeWNA16Config.from_config(config).get_quant_method(
layer, prefix)
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
def apply_vllm_mapper(self, hf_to_vllm_mapper):
if self.modules_in_block_to_quantize is not None:
self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
self.modules_in_block_to_quantize)
def maybe_update_config(self,
model_name: str,
revision: Optional[str] = None):
if self.modules_in_block_to_quantize:
if is_list_of(self.modules_in_block_to_quantize, list):
# original modules_in_block_to_quantize: list[list[str]]
# flatten original modules_in_block_to_quantize
self.modules_in_block_to_quantize = [
item for sublist in self.modules_in_block_to_quantize
for item in sublist
]
return
unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
metadata = get_safetensors_params_metadata(model_name,
revision=revision)
quant_layers: set[str] = {
param_name.rsplit(".", 1)[0]
for param_name, info in metadata.items()
if (dtype := info.get('dtype', None))
and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
}
self.modules_in_block_to_quantize = list(quant_layers)
class ExllamaState(Enum):
UNUSED = enum.auto()
UNINITIALIZED = enum.auto()
READY = enum.auto()
class GPTQLinearMethod(LinearMethodBase):
"""Linear method for GPTQ.
Args:
quant_config: The GPTQ quantization config.
"""
def __init__(self, quant_config: GPTQConfig):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del output_size # Unused.
weight_loader = extra_weight_attrs.get("weight_loader")
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if (output_size_per_partition % self.quant_config.pack_factor.numerator
!= 0):
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
exllama_state = ExllamaState.UNINITIALIZED
scale_and_zero_size = input_size // group_size
scale_and_zero_input_dim = None
if (input_size != input_size_per_partition
and self.quant_config.group_size != -1):
# For act-order models, we cannot use Exllama for row parallel layer
if self.quant_config.desc_act:
exllama_state = ExllamaState.UNUSED
else:
# we need to partition qzeros and scales for exllama kernel
scale_and_zero_size = input_size_per_partition // group_size
scale_and_zero_input_dim = 0
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
g_idx = RowvLLMParameter(data=torch.tensor(
[
i // self.quant_config.group_size
for i in range(input_size_per_partition)
],
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
qzeros_args = {
"data":
torch.empty(
scale_and_zero_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
"weight_loader":
weight_loader
}
weight_scale_args = {
"data":
torch.empty(
scale_and_zero_size,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if scale_and_zero_input_dim is None:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
qzeros = PackedColumnParameter(
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
else:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
qzeros = PackedvLLMParameter(
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("qzeros", qzeros)
layer.register_parameter("scales", scales)
layer.exllama_state = exllama_state
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# for torch.compile
layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data, requires_grad=False)
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
if layer.exllama_state == ExllamaState.UNINITIALIZED:
if self.quant_config.desc_act:
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
else:
layer.g_idx.data = torch.empty((0, ),
dtype=torch.int,
device=layer.g_idx.device)
layer.exllama_state = ExllamaState.READY
ops.gptq_shuffle(layer.qweight, layer.g_idx,
self.quant_config.weight_bits)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
out_shape = x.shape[:-1] + (layer.qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
layer.scales, layer.g_idx,
layer.exllama_state == ExllamaState.READY,
self.quant_config.weight_bits)
if bias is not None:
output.add_(bias)
return output.reshape(out_shape)

View File

@@ -0,0 +1,448 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
from packaging import version
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
BitBLASLinearKernel, MPLinearLayerConfig)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
BITBLAS_SUPPORTED_NUM_BITS as GPTQ_BITBLAS_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
BITBLAS_SUPPORTED_SYM as GPTQ_BITBLAS_SUPPORTED_SYM)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
MINIMUM_BITBLAS_VERSION, bitblas_repeat_scales_on_all_ranks,
check_bitblas_supported, verify_bitblas_supported)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
class GPTQBitBLASConfig(QuantizationConfig):
"""Config class for GPTQ BitBLAS"""
# (num_bits, is_sym) -> quant_type
TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
}
TORCH_DTYPE = torch.float16
GPTQ_CKPT_STORAGE_DTYPE = (
"int32" # GPTQ Default Checkpoints use int32 as storage dtype
)
GPTQ_BITBLAS_STORAGE_DTYPE = "int8" # BitBLAS uses int8 as storage dtype
TORCH_BITBLAS_STORAGE_DTYPE = getattr(torch, GPTQ_BITBLAS_STORAGE_DTYPE)
# "original" or "rescale" or "quantized",
# the gptq_bitblas prefer "quantized"
ZEROS_MODE = "quantized"
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
is_sym: bool,
quant_method: Optional[str],
lm_head_quantized: bool,
) -> None:
try:
import bitblas
if version.parse(bitblas.__version__) < version.parse(
MINIMUM_BITBLAS_VERSION):
raise ImportError(
"bitblas version is wrong. Please "
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
except ImportError as e:
bitblas_import_exception = e
raise ValueError(
"Trying to use the bitblas backend, but could not import"
f"with the following error: {bitblas_import_exception}. "
"Please install bitblas through the following command: "
f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
) from bitblas_import_exception
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act = False
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
self.quant_method = quant_method
self.lm_head_quantized = lm_head_quantized
# Verify
if self.weight_bits not in GPTQ_BITBLAS_SUPPORTED_NUM_BITS:
raise ValueError(
f"BitBLAS does not support weight_bits = {self.weight_bits}. "
f"Only weight_bits = {GPTQ_BITBLAS_SUPPORTED_NUM_BITS} "
"are supported.")
if self.is_sym not in GPTQ_BITBLAS_SUPPORTED_SYM:
raise ValueError(
f"BitBLAS does not support is_sym = {self.is_sym}. "
f"Only sym = {GPTQ_BITBLAS_SUPPORTED_SYM} are supported.")
self.storage_dtype = self.GPTQ_BITBLAS_STORAGE_DTYPE
storage_nbit = int("".join(c for c in self.GPTQ_CKPT_STORAGE_DTYPE
if c.isdigit()))
# 4 Bits packed into 32 bit datatype.
self.pack_factor = storage_nbit // weight_bits
self.nbits = weight_bits
# Zeros type for the quantized weights.
self.zeros_mode = self.ZEROS_MODE
if (weight_bits, is_sym) not in self.TYPE_MAP:
raise ValueError("Unsupported quantization config: "
f"bits={weight_bits}, sym={is_sym}")
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
def __repr__(self) -> str:
return (f"GPTQBitBLASConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act})"
f"is_sym={self.is_sym}, "
f"quant_method={self.quant_method})")
@classmethod
def get_name(cls) -> QuantizationMethods:
return "gptq_bitblas"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: dict[str, Any]) -> "GPTQBitBLASConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
is_sym = cls.get_from_keys(config, ["sym"])
quant_method = cls.get_from_keys(config, ["quant_method"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, is_sym, quant_method,
lm_head_quantized)
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "bitblas"
or user_quant == "gptq_bitblas")
if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
if can_convert and user_quant == "gptq":
logger.info("Detected that the model can run with gptq_bitblas"
", however you specified quantization=gptq explicitly,"
" so forcing gptq. Use quantization=gptq_bitblas for"
" faster inference")
return None
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["GPTQBitBLASLinearMethod"]:
if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
and self.lm_head_quantized):
return GPTQBitBLASLinearMethod(self)
return None
@property
def torch_storage_dtype(self) -> torch.dtype:
return self.TORCH_BITBLAS_STORAGE_DTYPE
@classmethod
def is_gptq_bitblas_compatible(cls, quant_config: dict[str, Any]):
# Extract data from quant config.
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
sym = quant_config.get("sym")
desc_act = quant_config.get("desc_act")
# temporarily disable on ROCm platform
if not current_platform.is_cuda():
return False
# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or sym is None
or desc_act is None):
return False
if (num_bits, sym) not in cls.TYPE_MAP:
return False
# If the capability of the device is too low, cannot convert.
major, minor = torch.cuda.get_device_capability()
device_capability = major * 10 + minor
if device_capability < cls.get_min_capability():
return False
# Otherwise, can convert if model satisfies bitblas constraints.
return check_bitblas_supported(quant_type=cls.TYPE_MAP[(num_bits,
sym)],
group_size=group_size)
class GPTQBitBLASLinearMethod(LinearMethodBase):
"""Linear method for GPTQ BitBLAS.
Args:
quant_config: The GPTQ BitBLAS quantization config.
"""
kernel_type = BitBLASLinearKernel
_kernel_backends_being_used: set[str] = set()
def __init__(self, quant_config: GPTQBitBLASConfig) -> None:
self.quant_config = quant_config
# Verify supported on platform.
verify_bitblas_supported(quant_type=self.quant_config.quant_type,
group_size=self.quant_config.group_size)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
"""Creates quantized weights for use in linear operations.
The function initializes and returns a dictionary containing
quantized weights, scales, and zeros
for performing quantized matrix multiplication operations.
Args:
input_size_per_partition: The size of the input partition.
output_partition_sizes: The size of the output partition.
input_size: The total size of the input (unused).
output_size: The total size of the output (unused).
params_dtype:
The data type of the parameters (expected to be torch.float16).
Returns:
A dictionary containing the quantized weights ('qweight'),
scales ('scales'), and zeros ('zeros').
Raises:
ValueError: If `params_dtype` is not `torch.float16` or if the input
size per partition is not divisible by the group size
in `quant_config`.
"""
if params_dtype != torch.float16:
raise ValueError("Parameter data type must be torch.float16, "
f"but got {params_dtype}")
# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
if input_size_per_partition % group_size != 0:
raise ValueError(
f"Input size per partition ({input_size_per_partition}) must "
f"be divisible by group size ({self.quant_config.group_size})."
)
kernel_type = self.kernel_type
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
is_row_parallel = input_size != input_size_per_partition
weight_loader = extra_weight_attrs.get("weight_loader")
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=\
(input_size_per_partition, output_size_per_partition),
weight_type=self.quant_config.quant_type,
act_type=params_dtype,
group_size=self.quant_config.group_size,
zero_points=False,
has_g_idx=self.quant_config.desc_act
)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for GPTQBitBLASLinearMethod",
kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
# Determine sharding
if bitblas_repeat_scales_on_all_ranks(self.quant_config.desc_act,
self.quant_config.group_size,
is_row_parallel):
# By setting scale_dim == None, weight_loader will
# repeat the scales on each GPU in TP>1 case.
scales_and_zp_input_dim = None
scales_and_zp_size = input_size // group_size
else:
# By setting scale_dim == 0, weight_loader will
# shard the scales in TP>1 case.
scales_and_zp_input_dim = 0
scales_and_zp_size = input_size_per_partition // group_size
# Init buffers
# Quantized weights
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
# Activation order
# Ignore warning from fused linear layers such as QKVParallelLinear.
g_idx = RowvLLMParameter(data=torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
# Scales
scales = Parameter(
torch.empty(
scales_and_zp_size,
output_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
},
)
# Quantized zero-points
qzeros_args = {
"data":
torch.empty(
scales_and_zp_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
"weight_loader":
weight_loader
}
weight_scale_args = {
"data":
torch.empty(
scales_and_zp_size,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if scales_and_zp_input_dim is None:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
qzeros = PackedColumnParameter(
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
else:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
qzeros = PackedvLLMParameter(
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros)
self.kernel = kernel_type(
mp_linear_kernel_config,
w_q_param_name="qweight",
w_s_param_name="scales",
w_zp_param_name="qzeros",
w_gidx_param_name="g_idx",
bitblas_quant_config=self.quant_config,
)
# Initialize or retrieve the BitBLAS matrix multiplication operator.
self.kernel.configure_bitblas_matmul(
input_size_per_partition,
output_size_per_partition,
params_dtype=params_dtype,
bias=False,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
out = self.kernel.apply_gptq_bitblas_linear(layer, x)
if bias is not None:
out.add_(bias)
return out

View File

@@ -0,0 +1,751 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
from typing import Any, Callable, Optional, Union
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_dynamic_override, get_linear_quant_method, override_config)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported, check_moe_marlin_supports_layer,
marlin_make_workspace_new, marlin_moe_permute_scales, marlin_permute_bias,
marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.transformers_utils.config import get_safetensors_params_metadata
from vllm.utils import is_list_of
logger = init_logger(__name__)
def get_moe_quant_method(
config: QuantizationConfig,
layer: torch.nn.Module,
prefix: str,
moe_method_cls: type,
):
cloned_config = deepcopy(config)
if isinstance(layer, FusedMoE):
# False = skip module, None = no override, else = Positive match
if get_dynamic_override( # noqa: E712
cloned_config, # noqa: E712
layer_name=prefix) == False: # noqa: E712
return UnquantizedFusedMoEMethod(layer.moe_config)
if prefix:
# Dynamic per module/layer rules may override base config
override_config(cloned_config, prefix=prefix)
return moe_method_cls(cloned_config, layer.moe_config)
return None
class GPTQMarlinConfig(QuantizationConfig):
"""Config class for GPTQ Marlin"""
# (num_bits, is_sym) -> quant_type
TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
}
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
is_sym: bool,
lm_head_quantized: bool,
dynamic: dict[str, dict[str, Union[int, bool]]],
full_config: dict[str, Any],
modules_in_block_to_quantize: Optional[list[str]] = None) -> None:
super().__init__()
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act = False
# GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized.
# Format is dict[str, dict] where key is a regex string that can
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
# matching of a module.
# Default to positive match, override base quant config mode, if no
# prefix is used. Value is in dict format of field key and override
# value.
# Negative matching will skip quantization init for this module
# entirely:
# non-quantized inference. More details and quantization examples can be
# found at: https://github.com/ModelCloud/GPTQModel
# Example:
# # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
# # last 1/4 of the layers 16-21 has 8bit and group_size 64
# dynamic = {
# #`.*\.` matches the layers_node prefix
# # positive match layer 10-15
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
# # positive match layer 16-21
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
# }
self.dynamic = dynamic
self.weight_bits = weight_bits
self.is_sym = is_sym
self.pack_factor = 32 // weight_bits # packed into int32
self.group_size = group_size
self.desc_act = desc_act
self.lm_head_quantized = lm_head_quantized
self.full_config = full_config
if (weight_bits, is_sym) not in self.TYPE_MAP:
raise ValueError("Unsupported quantization config: "
f"bits={weight_bits}, sym={is_sym}")
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
self.modules_in_block_to_quantize = modules_in_block_to_quantize or []
# used to identify GPTQ model quantized by autoround
self.autoround_version = full_config.get("autoround_version", "")
def __repr__(self) -> str:
return (
f"GPTQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}, "
f"lm_head_quantized={self.lm_head_quantized}, "
f"dynamic={self.dynamic}, "
f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})"
)
@classmethod
def get_name(cls) -> QuantizationMethods:
return "gptq_marlin"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: dict[str, Any]) -> "GPTQMarlinConfig":
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
dynamic = {} if dynamic is None else dynamic
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
is_sym = cls.get_from_keys(config, ["sym"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
modules_in_block_to_quantize = cls.get_from_keys_or(
config, ["modules_in_block_to_quantize"], default=None)
return cls(weight_bits, group_size, desc_act, is_sym,
lm_head_quantized, dynamic, config,
modules_in_block_to_quantize)
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "marlin"
or user_quant == "gptq_marlin")
if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
if can_convert and user_quant == "gptq":
logger.info("Detected that the model can run with gptq_marlin"
", however you specified quantization=gptq explicitly,"
" so forcing gptq. Use quantization=gptq_marlin for"
" faster inference")
return None
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, FusedMoE):
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config)
if not check_moe_marlin_supports_layer(layer, self.group_size):
logger.warning_once(
f"Layer '{prefix}' is not supported by GPTQMoeMarlin. "
"Falling back to Moe WNA16 kernels.")
return MoeWNA16Config.from_config(
self.full_config).get_quant_method(layer, prefix)
return get_moe_quant_method(self, layer, prefix,
GPTQMarlinMoEMethod)
return get_linear_quant_method(self, layer, prefix,
GPTQMarlinLinearMethod)
@classmethod
def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]):
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
sym = quant_config.get("sym")
desc_act = quant_config.get("desc_act")
if not current_platform.is_cuda():
return False
if quant_method != "gptq":
return False
# Marlin conversion is only valid if required properties are found
if (num_bits is None or group_size is None or sym is None
or desc_act is None):
return False
if (num_bits, sym) not in cls.TYPE_MAP:
return False
return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
group_size=group_size)
def apply_vllm_mapper(self, hf_to_vllm_mapper):
if self.modules_in_block_to_quantize is not None:
self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
self.modules_in_block_to_quantize)
def maybe_update_config(self,
model_name: str,
revision: Optional[str] = None):
if self.modules_in_block_to_quantize:
if is_list_of(self.modules_in_block_to_quantize, list):
# original modules_in_block_to_quantize: list[list[str]]
# flatten original modules_in_block_to_quantize
self.modules_in_block_to_quantize = [
item for sublist in self.modules_in_block_to_quantize
for item in sublist
]
return
unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
metadata = get_safetensors_params_metadata(model_name,
revision=revision)
quant_layers: set[str] = {
param_name.rsplit(".", 1)[0]
for param_name, info in metadata.items()
if (dtype := info.get('dtype', None))
and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
}
self.modules_in_block_to_quantize = list(quant_layers)
class GPTQMarlinLinearMethod(LinearMethodBase):
"""Linear method for GPTQ Marlin.
Args:
quant_config: The GPTQ Marlin quantization config.
"""
_kernel_backends_being_used: set[str] = set()
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config
# Verify supported on platform.
verify_marlin_supported(quant_type=self.quant_config.quant_type,
group_size=self.quant_config.group_size)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
output_size_per_partition = sum(output_partition_sizes)
is_row_parallel = input_size != input_size_per_partition
weight_loader = extra_weight_attrs.get("weight_loader")
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=\
(input_size_per_partition, output_size_per_partition),
weight_type=self.quant_config.quant_type,
act_type=params_dtype,
group_size=self.quant_config.group_size,
zero_points=False,
has_g_idx=self.quant_config.desc_act
)
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for GPTQMarlinLinearMethod",
kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
# Determine sharding
if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
self.quant_config.group_size,
is_row_parallel):
# By setting scale_dim == None, weight_loader will
# repeat the scales on each GPU in TP>1 case.
scales_and_zp_input_dim = None
scales_and_zp_size = input_size // group_size
else:
# By setting scale_dim == 0, weight_loader will
# shard the scales in TP>1 case.
scales_and_zp_input_dim = 0
scales_and_zp_size = input_size_per_partition // group_size
# Quantized weights
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
# Activation order
g_idx = RowvLLMParameter(data=torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
qzeros_args = {
"data":
torch.empty(
scales_and_zp_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
"weight_loader":
weight_loader
}
weight_scale_args = {
"data":
torch.empty(
scales_and_zp_size,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if scales_and_zp_input_dim is None:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
qzeros = PackedColumnParameter(
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
else:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
qzeros = PackedvLLMParameter(
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros)
self.kernel = kernel_type(mp_linear_kernel_config,
w_q_param_name="qweight",
w_s_param_name="scales",
w_zp_param_name="qzeros",
w_gidx_param_name="g_idx")
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)
class GPTQMarlinMoEMethod(FusedMoEMethodBase):
"""MoE Marlin method with quantization."""
def __init__(
self,
quant_config: GPTQMarlinConfig,
moe: FusedMoEConfig,
) -> None:
super().__init__(moe)
self.quant_config = quant_config
if self.quant_config.quant_type.size_bits == 4:
self.quant_type = scalar_types.uint4b8
elif self.quant_config.quant_type.size_bits == 8:
self.quant_type = scalar_types.uint8b128
else:
raise ValueError(
"GPTQMarlinMoEMethod only supports int4 and int8 now.")
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
intermediate_size_full = extra_weight_attrs.pop(
"intermediate_size_full")
self.is_k_full = (not self.quant_config.desc_act) or (
intermediate_size_per_partition == intermediate_size_full)
if self.quant_config.group_size != -1:
scales_size13 = hidden_size // self.quant_config.group_size
w2_scales_size = (intermediate_size_full
if self.quant_config.desc_act else
intermediate_size_per_partition)
scales_size2 = (w2_scales_size // self.quant_config.group_size)
strategy = FusedMoeWeightScaleSupported.GROUP.value
else:
scales_size13 = 1
scales_size2 = 1
strategy = FusedMoeWeightScaleSupported.CHANNEL.value
extra_weight_attrs.update({
"quant_method": strategy,
"is_transposed": True
})
# Fused gate_up_proj (column parallel)
w13_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size // self.quant_config.pack_factor,
2 * intermediate_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_qweight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs)
# down_proj (row parallel)
w2_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size_per_partition //
self.quant_config.pack_factor,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_qweight", w2_qweight)
set_weight_attrs(w2_qweight, extra_weight_attrs)
# up_proj scales
w13_scales = torch.nn.Parameter(
torch.empty(num_experts,
scales_size13,
2 * intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w13_scales", w13_scales)
set_weight_attrs(w13_scales, extra_weight_attrs)
# down_proj scales
w2_scales = torch.nn.Parameter(
torch.empty(num_experts,
scales_size2,
hidden_size,
dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w2_scales", w2_scales)
set_weight_attrs(w2_scales, extra_weight_attrs)
# don't shard the w2 scales when running act order
set_weight_attrs(w2_scales,
{"load_full_w2": self.quant_config.desc_act})
# up_proj scales
w13_qzeros = torch.nn.Parameter(
torch.empty(num_experts,
scales_size13,
2 * intermediate_size_per_partition //
self.quant_config.pack_factor,
dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w13_qzeros", w13_qzeros)
set_weight_attrs(w13_qzeros, extra_weight_attrs)
# down_proj scales
w2_qzeros = torch.nn.Parameter(
torch.empty(num_experts,
scales_size2,
hidden_size // self.quant_config.pack_factor,
dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)
# don't shard the w2 scales when running act order
set_weight_attrs(w2_qzeros,
{"load_full_w2": self.quant_config.desc_act})
w13_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_g_idx", w13_g_idx)
set_weight_attrs(w13_g_idx, extra_weight_attrs)
w2_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_g_idx", w2_g_idx)
set_weight_attrs(w2_g_idx, extra_weight_attrs)
w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_g_idx_sort_indices",
w13_g_idx_sort_indices)
set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_g_idx_sort_indices",
w2_g_idx_sort_indices)
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
device = layer.w13_qweight.device
layer.workspace = marlin_make_workspace_new(device, 4)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Process act_order
if self.quant_config.desc_act:
# Get sorting based on g_idx
num_experts = layer.w13_g_idx.shape[0]
w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
for e in range(num_experts):
w13_g_idx_sort_indices[e] = torch.argsort(
layer.w13_g_idx[e]).to(torch.int32)
w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
torch.int32)
w13_sorted_g_idx[e] = layer.w13_g_idx[e][
w13_g_idx_sort_indices[e]]
w2_sorted_g_idx[e] = layer.w2_g_idx[e][
w2_g_idx_sort_indices[e]]
replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
replace_parameter(layer, "w13_g_idx_sort_indices",
w13_g_idx_sort_indices)
replace_parameter(layer, "w2_g_idx_sort_indices",
w2_g_idx_sort_indices)
else:
# Reset g_idx related tensors
num_experts = layer.w13_g_idx.shape[0]
device = layer.w13_g_idx.device
layer.w13_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
layer.w2_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
# Repack weights
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
layer.w13_qweight,
layer.w13_g_idx_sort_indices,
layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
layer.w13_qweight.shape[2],
self.quant_config.quant_type.size_bits,
)
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
layer.w2_qweight,
layer.w2_g_idx_sort_indices,
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
layer.w2_qweight.shape[2],
self.quant_config.quant_type.size_bits,
)
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales,
size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size,
)
replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales,
size_k=layer.w2_scales.shape[1] *
(self.quant_config.group_size if self.quant_config.group_size != -1
else self.quant_config.pack_factor),
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
)
replace_parameter(layer, "w2_scales", marlin_w2_scales)
if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
layer.w13_bias.data = marlin_permute_bias(layer.w13_bias)
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `GPTQMarlinMoEMethod` yet.")
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,
getattr(layer, "w13_bias", None),
getattr(layer, "w2_bias", None),
layer.w13_scales,
layer.w2_scales,
router_logits,
topk_weights,
topk_ids,
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
g_idx1=layer.w13_g_idx,
g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
workspace=layer.workspace,
is_k_full=self.is_k_full)

View File

@@ -0,0 +1,297 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
GPTQ_MARLIN_24_TILE = 16
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [
scalar_types.uint4b8, scalar_types.uint8b128
]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
class GPTQMarlin24Config(QuantizationConfig):
"""Config class for Marlin24.
"""
def __init__(
self,
weight_bits: int,
group_size: int,
) -> None:
super().__init__()
quant_type = {
4: scalar_types.uint4b8,
8: scalar_types.uint8b128,
}.get(weight_bits)
self.group_size = group_size
# Verify
if quant_type is None or \
quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES:
raise ValueError(
f"Marlin_24 does not support quant_type = {quant_type}. "
f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} "
"are supported.")
if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
raise ValueError(
f"Marlin_24 does not support group_size = {self.group_size}. "
f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
"are supported.")
self.quant_type = quant_type
# 4 Bits packed into 32 bit datatype.
self.pack_factor = 32 // self.quant_type.size_bits
# Tile size used by marlin kernels.
self.tile_size = 16
# Min out_features dim
self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N
# Min in_features dim
self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K
# Max parallel problems to solve at once (improves large
# batch performance)
self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL
# Permutation length used by the marlin kernels.
self.perm_len = 1024
def __repr__(self) -> str:
return "Marlin24Config(quant_type={}, group_size={})".format(
self.quant_type, self.group_size)
@classmethod
def get_name(cls) -> QuantizationMethods:
return "gptq_marlin_24"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half]
@classmethod
# Need to figure it out
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: dict[str, Any]) -> "GPTQMarlin24Config":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
return cls(weight_bits, group_size)
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
is_marlin_24_format = (
hf_quant_cfg.get("checkpoint_format") == "marlin_24")
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
or user_quant == "gptq_marlin_24")
if is_marlin_24_format and is_valid_user_quant:
msg = ("The model is serialized in {} format. "
"Using {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
return None
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["GPTQMarlin24LinearMethod"]:
if isinstance(layer, LinearBase):
return GPTQMarlin24LinearMethod(self)
return None
class GPTQMarlin24LinearMethod(LinearMethodBase):
"""Linear method for Marlin24.
Args:
quant_config: The Marlin24 quantization config.
"""
def __init__(self, quant_config: GPTQMarlin24Config):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del output_size # Unused.
weight_loader = extra_weight_attrs["weight_loader"]
if params_dtype != torch.float16:
raise ValueError(
f"The params dtype must be float16, but got {params_dtype}")
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.min_n_threads != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f"min_n_threads = {self.quant_config.min_n_threads}.")
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f"pack_factor = {self.quant_config.pack_factor}.")
# Validate input_size_per_partition
if input_size_per_partition % self.quant_config.min_k_threads != 0:
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"min_k_threads = {self.quant_config.min_k_threads}.")
if (self.quant_config.group_size != -1 and
input_size_per_partition % self.quant_config.group_size != 0):
raise ValueError(f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"group_size = {self.quant_config.group_size}.")
# Check that we have at least 4 tiles horizontally in the shard
num_tiles_per_perm = self.quant_config.perm_len // (
self.quant_config.tile_size**2)
if output_size_per_partition % num_tiles_per_perm != 0:
raise ValueError(
"Each permutation group must reside on the same gpu")
# Quantized 4Bit weights packed into Int32.
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.tile_size // 2,
output_size_per_partition * self.quant_config.tile_size //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
marlin_tile_size=self.quant_config.tile_size,
weight_loader=weight_loader)
# Meta
meta = PackedvLLMParameter(data=torch.empty(
input_size_per_partition // 8 // 2 // 2,
output_size_per_partition * 2,
device="cuda",
dtype=torch.int16,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=1,
marlin_tile_size=2,
weight_loader=weight_loader)
# Determine if channelwise or not
input_groups = (1 if self.quant_config.group_size == -1 else
input_size_per_partition //
self.quant_config.group_size)
weight_scale_args = {
"data":
torch.empty(
input_groups,
output_size_per_partition,
device="cuda",
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if input_groups == 1:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
else:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
# Allocate workspace (Used for internal locking mechanism)
max_workspace_size = (
output_size_per_partition //
self.quant_config.min_n_threads) * self.quant_config.max_parallel
workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
device="cuda",
dtype=torch.int),
weight_loader=weight_loader)
layer.register_parameter("B_24", qweight)
layer.register_parameter("B_meta", meta)
layer.register_parameter("s", scales)
layer.register_parameter("workspace", workspace)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# required by torch.compile
layer.B_24 = Parameter(layer.B_24.data, requires_grad=False)
layer.s = Parameter(layer.s.data, requires_grad=False)
layer.B_meta = Parameter(layer.B_meta.data, requires_grad=False)
layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qweight = layer.B_24
meta = layer.B_meta
scales = layer.s
workspace = layer.workspace
x_2d = x.view(-1, x.shape[-1])
size_m = x_2d.shape[0]
size_k = x_2d.shape[1]
size_n = scales.shape[1]
output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
workspace,
self.quant_config.quant_type,
size_m, size_n, size_k)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
if bias is not None:
output.add_(bias) # In-place add
return output

View File

@@ -0,0 +1,333 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
marlin_make_empty_g_idx, marlin_permute_bias, marlin_permute_scales)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace)
from vllm.model_executor.layers.quantization.utils.quant_utils import gptq_pack
from vllm.model_executor.parameter import (BasevLLMParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
class HQQMarlinConfig(QuantizationConfig):
"""Config class for HQQ Marlin"""
def __init__(
self,
weight_bits: int,
group_size: int,
skip_modules: Optional[list[str]] = None,
) -> None:
super().__init__()
assert group_size == 64, ("The only supported HQQ group size is "
"currently 64.")
assert weight_bits == 4, ("The only supported HQQ quantization "
"bitsize is currently 4.")
self.weight_bits = weight_bits
self.group_size = group_size
self.pack_factor = 32 // weight_bits # packed into int32 in GPTQ format
self.quant_type = scalar_types.uint4
self.skip_modules = skip_modules
def __repr__(self) -> str:
return (f"HQQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size})")
@classmethod
def get_name(cls) -> QuantizationMethods:
return "hqq"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: dict[str, Any]) -> "HQQMarlinConfig":
wq_params = (config["quant_config"]["weight_quant_params"])
weight_bits = cls.get_from_keys(wq_params, ["nbits"])
group_size = cls.get_from_keys(wq_params, ["group_size"])
skip_modules = config["skip_modules"]
return cls(weight_bits, group_size, skip_modules)
def is_layer_skipped(self, prefix: str) -> bool:
# Split the prefix into its dot-separated components
components = prefix.split('.')
# Check if any of the skip modules exactly matches any component
return self.skip_modules is not None and any(
module_name in components for module_name in self.skip_modules)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
if self.is_layer_skipped(prefix):
return UnquantizedLinearMethod()
return HQQMarlinMethod(self)
return None
# Empty HQQ parameter, will be ignored during loading
class HQQEmptyParameter(BasevLLMParameter):
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
pass
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
pass
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
pass
def error_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
raise ValueError("No loader provided for HQQ parameter!")
# HQQ packing creates issues with sharding - therefore, prior to loading, we
# repack to GPTQ. We also reshape the weights to their proper GPTQ shape.
class HQQweightParameter(PackedvLLMParameter):
# unpack function from https://github.com/mobiusml/hqq
def unpack_4bit_u8(self,
W_q: torch.Tensor) -> torch.Tensor: # uint8/2 > uint8
assert self.weight_bits == 4, "Unsupported quant bitsize (must be 4)"
dtype = torch.uint8
step = W_q.shape[0]
tmp = torch.empty([2 * step, W_q.shape[1]],
dtype=dtype,
device=W_q.device)
tmp[:step] = (W_q & 0b11110000) >> 4
tmp[step:] = W_q & 0b00001111
return tmp
def __init__(self, packed_factor: int, packed_dim: int, weight_bits: int,
**kwargs):
super().__init__(packed_factor, packed_dim, None, **kwargs)
self.weight_bits = weight_bits
self.input_shape = self.shape[self.input_dim] * self.packed_factor
self.output_shape = self.shape[self.output_dim]
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
loaded_weight = self.unpack_4bit_u8(loaded_weight)
loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose(
1, 0)
loaded_weight = gptq_pack(loaded_weight, self.weight_bits,
loaded_weight.shape[0],
loaded_weight.shape[1])
super().load_merged_column_weight(loaded_weight, **kwargs)
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
loaded_weight = self.unpack_4bit_u8(loaded_weight)
loaded_weight = loaded_weight.reshape(self.output_shape,
-1).transpose(1, 0)
loaded_weight = gptq_pack(loaded_weight, self.weight_bits,
loaded_weight.shape[0],
loaded_weight.shape[1])
super().load_row_parallel_weight(loaded_weight)
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
loaded_weight = self.unpack_4bit_u8(loaded_weight)
loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose(
1, 0)
loaded_weight = gptq_pack(loaded_weight, self.weight_bits,
loaded_weight.shape[0],
loaded_weight.shape[1])
super().load_qkv_weight(loaded_weight, **kwargs)
# Zero points and scales in HQQ must also be reshaped to correspond to W_q's
# GPTQ shape (transposed - we transpose them too when processing weights).
class HQQZeroScaleParameter(GroupQuantScaleParameter):
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
loaded_weight = loaded_weight.reshape(-1, self.shape[1])
super().load_merged_column_weight(loaded_weight, **kwargs)
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
loaded_weight = loaded_weight.reshape(self.shape[0], -1)
super().load_row_parallel_weight(loaded_weight)
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
loaded_weight = loaded_weight.reshape(-1, self.shape[1])
super().load_qkv_weight(loaded_weight, **kwargs)
class HQQMarlinMethod(LinearMethodBase):
"""Linear method for HQQ Marlin.
"""
def __init__(
self,
quant_config: HQQMarlinConfig,
):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
self.output_size_per_partition = sum(output_partition_sizes)
self.input_size_per_partition = input_size_per_partition
weight_loader = extra_weight_attrs.get("weight_loader", error_loader)
self.scales_and_zp_size = (input_size_per_partition //
self.quant_config.group_size)
qweight = HQQweightParameter(
data=torch.empty(
self.input_size_per_partition // self.quant_config.pack_factor,
self.output_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_bits=self.quant_config.weight_bits,
weight_loader=weight_loader)
zeros = HQQZeroScaleParameter(data=torch.empty(
self.output_size_per_partition,
self.scales_and_zp_size,
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
scales = HQQZeroScaleParameter(data=torch.empty(
self.output_size_per_partition,
self.scales_and_zp_size,
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("W_q", qweight)
layer.register_parameter("zero", zeros)
layer.register_parameter("scale", scales)
# Ignore extra parameters in the HQQ model.
# To be added as needed.
ignore_parameters = ("axis", "channel_wise", "compute_dtype",
"encoded_state_dict", "group_size", "nbits",
"offload_meta", "optimize", "packing",
"quant_scale", "quant_zero", "round_zero",
"shape", "stores_quant_config",
"unpack_view_dtype", "view_as_float")
for name in ignore_parameters:
layer.register_parameter(
name,
HQQEmptyParameter(data=torch.empty(0),
weight_loader=weight_loader))
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
dev = layer.W_q.device
# Repack to Marlin
sort_indices = torch.empty(0, dtype=torch.int, device=dev)
marlin_w_q = ops.gptq_marlin_repack(
layer.W_q,
sort_indices,
self.input_size_per_partition,
self.output_size_per_partition,
self.quant_config.weight_bits,
).to(dev)
marlin_s = marlin_permute_scales(layer.scale.transpose(1, 0),
self.input_size_per_partition,
self.output_size_per_partition,
self.quant_config.group_size).to(dev)
marlin_zp = marlin_permute_scales(layer.zero.transpose(1, 0),
self.input_size_per_partition,
self.output_size_per_partition,
self.quant_config.group_size).to(dev)
layer.g_idx = marlin_make_empty_g_idx(dev)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(dev)
layer.marlin_qweight = marlin_w_q
layer.marlin_zeros = marlin_zp
layer.marlin_scales = marlin_s
if hasattr(layer, "bias") and layer.bias is not None:
layer.bias.data = marlin_permute_bias(layer.bias)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
workspace = MarlinWorkspace(self.output_size_per_partition,
GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
scales = layer.marlin_scales
zeros = layer.marlin_zeros
orig_type = x.dtype
if orig_type != torch.float16:
x = x.to(torch.float16)
scales = scales.to(torch.float16)
zeros = zeros.to(torch.float16)
marlin_out = ops.gptq_marlin_gemm(
x,
None,
layer.marlin_qweight,
bias,
scales,
None,
zeros,
layer.g_idx,
layer.g_idx_sort_indices,
workspace.scratch,
scalar_types.uint4,
x.shape[0],
self.output_size_per_partition,
self.input_size_per_partition,
True, # is_k_full
False, # use atomic add
True, # use 32-bit reduce
True, # use float zp
)
if orig_type != torch.float16:
marlin_out = marlin_out.to(orig_type)
return marlin_out

View File

@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Intel Gaudi supports quantization of various modules and functions,
# including, but not limited to `Linear`, `KVCache`, `Matmul` and `Softmax`.
# During model loading,
# INC will patch layers with quantization/dequantization operators.
# Meanwhile, INC will convert original weight to target datatype
# and loading to target device.
# static scaling should be provided through Quant_CONFIG:
# `QUANT_CONFIG` is an environment variable,
# that points to the measurement or quantization JSON config file.
# The measurement configuration file is used during the calibration procedure,
# to collect measurements for a given model.
# The quantization configuration is used during inference.
# For more information, please refer to:
# https://docs.habana.ai/en/v1.21.1/PyTorch/vLLM_Inference/vLLM_FP8_Inference.html
from typing import Any, Optional
import torch
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
class INCConfig(QuantizationConfig):
"""Config class for FP8 using Intel Neural Compressor."""
@classmethod
def get_name(cls) -> QuantizationMethods:
return "inc"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16]
@classmethod
def from_config(cls, config: dict[str, Any]) -> "INCConfig":
raise AssertionError
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE):
return UnquantizedFusedMoEMethod(layer.moe_config)
return None
@classmethod
def get_min_capability(cls) -> int:
raise AssertionError
@staticmethod
def get_config_filenames() -> list[str]:
return []

View File

@@ -0,0 +1,156 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.platforms import current_platform
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm.
_FP8_DTYPE = current_platform.fp8_dtype()
_FP8_FINFO = torch.finfo(_FP8_DTYPE)
_FP8_MAX = 224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.max
_FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min
_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)
@CustomOp.register("quant_fp8")
class QuantFP8(CustomOp):
"""
Quantize input tensor to FP8 (per-tensor, per-token, or per-group).
This CustomOp supports both static and dynamic quantization.
"""
def __init__(self,
static: bool,
group_shape: GroupShape,
num_token_padding: Optional[int] = None,
column_major_scales: bool = False):
"""
:param static: static or dynamic quantization
:param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR,
or arbitrary block size)
:param num_token_padding: Pad the token dimension of output to this
size
:param column_major_scales: For group quantization, output scales in
column major format
"""
super().__init__()
self.static = static
self.group_shape = group_shape
self.num_token_padding = num_token_padding
self.column_major_scales = column_major_scales
self.is_group_quant = group_shape.is_per_group()
if self.is_group_quant:
assert not static, "Group quantization only supports dynamic mode"
self.group_size = group_shape.col
else:
assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR}
assert not static or group_shape == GroupShape.PER_TENSOR, \
"Only per-tensor scales supported for static quantization."
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
def forward_cuda(
self,
x: torch.Tensor,
scale: Optional[torch.Tensor] = None,
scale_ub: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if self.is_group_quant:
assert scale is None, "Group quantization is always dynamic"
from vllm.model_executor.layers.quantization.utils import fp8_utils
return fp8_utils.per_token_group_quant_fp8(
x,
group_size=self.group_size,
column_major_scales=self.column_major_scales,
dtype=_FP8_DTYPE)
assert (scale is not None) == self.static
assert scale_ub is None or (not self.static and self.group_shape
== GroupShape.PER_TOKEN
and scale_ub.numel() == 1)
return ops.scaled_fp8_quant(
x,
scale,
num_token_padding=self.num_token_padding,
scale_ub=scale_ub,
use_per_token_if_dynamic=self.use_per_token_if_dynamic)
def forward_native(
self,
x: torch.Tensor,
scale: Optional[torch.Tensor] = None,
scale_ub: Optional[torch.Tensor] = None,
):
if self.is_group_quant:
assert scale is None, "Group quantization is always dynamic"
return self._quantize_group_native(x)
assert (scale is not None) == self.static
assert scale_ub is None or (not self.static and self.group_shape
== GroupShape.PER_TOKEN
and scale_ub.numel() == 1)
if scale is None:
if self.group_shape == GroupShape.PER_TOKEN:
x_max, _ = x.abs().max(dim=-1)
x_max = x_max.unsqueeze(-1).to(torch.float32)
if scale_ub is not None:
x_max = x_max.clamp(max=scale_ub)
else:
x_max = x.abs().max().unsqueeze(-1).to(torch.float32)
scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
# Even for dynamic per-token scales,
# reciprocal performs slightly better than division
out = x.to(torch.float32) * scale.reciprocal()
out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
# This currently generates an extra Triton kernel in compilation.
# Fortunately, we don't use padding if compiling.
# TODO(luka): benchmark torch._scaled_mm to hopefully remove padding
# in general.
if self.num_token_padding is not None:
padding = max(self.num_token_padding - out.size(0), 0)
out = F.pad(out, (0, 0, 0, padding), "constant", 0.0)
return out, scale
def _quantize_group_native(
self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
orig_shape = x.shape
hidden_dim = x.shape[-1]
num_groups = (hidden_dim + self.group_size - 1) // self.group_size
padded_dim = num_groups * self.group_size
if padded_dim != hidden_dim:
padding = padded_dim - hidden_dim
x = F.pad(x, (0, padding), mode='constant', value=0.0)
x_grouped = x.view(-1, num_groups, self.group_size)
absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float()
scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
x_scaled = x_grouped / scales
x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
x_quant = x_quant.view(-1, padded_dim)
if padded_dim != hidden_dim:
x_quant = x_quant[..., :hidden_dim]
x_quant = x_quant.view(orig_shape)
scales = scales.squeeze(-1)
scales = scales.reshape(orig_shape[:-1] + (num_groups, ))
if self.column_major_scales:
scales = scales.transpose(-2, -1).contiguous()
return x_quant, scales

View File

@@ -0,0 +1,415 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional
import torch
from packaging import version
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm._ipex_ops import ipex_ops as ops
from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod,
is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
Fp8LinearMethod)
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
MIN_IPEX_VERSION = "2.6.0"
class IPEXConfig(QuantizationConfig):
"""INT8 quantization config class using IPEX for the CPU/XPU backend,
including AWQ, GPTQ.
"""
IPEX_QUANT_METHOD_MAP = {
"awq": 1,
"gptq": 0,
}
def __init__(
self,
method: str,
weight_bits: int,
group_size: int,
modules_to_not_convert: Optional[list[str]] = None,
desc_act: Optional[bool] = None,
lm_head_quantized: Optional[bool] = None,
) -> None:
super().__init__()
self.method = method
self.weight_bits = weight_bits
self.group_size = group_size
self.modules_to_not_convert = modules_to_not_convert or []
self.desc_act = desc_act
self.lm_head_quantized = lm_head_quantized
self.pack_factor = 32 // self.weight_bits
if self.weight_bits not in [4]:
raise ValueError(f"IPEX quantization supports weight bits [4], "
f"but got {self.weight_bits}.")
if self.method not in ["awq", "gptq"]:
raise ValueError(f"IPEX quantization supports [awq, gptq], "
f"but got {self.method}.")
def __repr__(self) -> str:
return (f"IPEXConfig(method={self.method},"
f"weight_bits={self.weight_bits}, "
f"group_size={self.group_size})")
@classmethod
def get_name(cls) -> QuantizationMethods:
return "ipex"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.float16]
@classmethod
def get_min_capability(cls) -> int:
return -1
@staticmethod
def get_config_filenames() -> list[str]:
return [
"quant_config.json",
"quantize_config.json",
]
@classmethod
def from_config(cls, config: dict[str, Any]) -> "IPEXConfig":
method = cls.get_from_keys(config, ["quant_method"]).lower()
if method == "awq":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config,
["q_group_size", "group_size"])
modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None)
return cls(method, weight_bits, group_size, modules_to_not_convert,
False, False)
# otherwise for gptq
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False)
return cls(method, weight_bits, group_size, [], desc_act,
lm_head_quantized)
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
if not current_platform.is_cpu() and not current_platform.is_xpu():
return None
quant_method = hf_quant_cfg.get("quant_method", "").lower()
if quant_method in ["awq", "gptq"]:
return cls.get_name()
return None
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["LinearMethodBase"]:
if isinstance(layer, LinearBase):
if self.method == "awq":
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
return IPEXAWQLinearMethod(self)
if self.method == "gptq":
return IPEXGPTQLinearMethod(self)
return None
class IPEXGPTQLinearMethod(GPTQLinearMethod):
"""GPTQ linear method using IPEX for the CPU/XPU backend.
"""
def __init__(self, quant_config: IPEXConfig):
self.quant_config = quant_config # type: ignore
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
bias = layer.bias if not layer.skip_bias_add else None
try:
import intel_extension_for_pytorch as ipex
if version.parse(
ipex.__version__) < version.parse(MIN_IPEX_VERSION):
raise ImportError(
"intel_extension_for_pytorch version is "
"wrong. Please install "
f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.")
except ImportError as err:
raise ImportError(
"Please install "
f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via "
f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`"
" to use IPEX-AWQ linear method.") from err
# Using the compute dtype (lowp_mode) as INT8 to leverage instructions
# with better performance.
lowp_mode = ipex.quantization.WoqLowpMode.INT8
# The weight will be de-packed from INT4 to INT8.
weight_dtype = ipex.quantization.WoqWeightDtype.INT4
# The float activation will be quantized (dynamic, per-token) to INT8.
act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
weight_dtype=weight_dtype,
lowp_mode=lowp_mode,
act_quant_mode=act_quant_mode,
group_size=self.quant_config.group_size,
)
layer.ipex_output_size = layer.qweight.shape[-1]
g_idx = layer.g_idx if self.quant_config.desc_act else None
layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \
IPEXWeightOnlyQuantizedLinear.from_weight(
layer.qweight,
layer.scales,
layer.qzeros,
layer.qweight.size(0),
layer.ipex_output_size,
qconfig=qconfig,
g_idx=g_idx,
bias=bias,
group_size=self.quant_config.group_size,
quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["gptq"]
)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
out = layer.ipex_qlinear(reshaped_x)
return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))
class IPEXAWQLinearMethod(AWQLinearMethod):
"""AWQ linear method using IPEX for the CPU/XPU backend.
"""
def __init__(self, quant_config: IPEXConfig):
self.quant_config = quant_config # type: ignore
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer=layer)
bias = layer.bias if not layer.skip_bias_add else None
try:
import intel_extension_for_pytorch as ipex
if version.parse(
ipex.__version__) < version.parse(MIN_IPEX_VERSION):
raise ImportError(
"intel_extension_for_pytorch version is "
"wrong. Please install "
f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.")
except ImportError as err:
raise ImportError(
"Please install "
f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via "
f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`"
" to use IPEX-AWQ linear method.") from err
# Using the compute dtype (lowp_mode) as INT8 to leverage instructions
# with better performance.
lowp_mode = ipex.quantization.WoqLowpMode.INT8
# The weight will be de-packed from INT4 to INT8.
weight_dtype = ipex.quantization.WoqWeightDtype.INT4
# The float activation will be quantized (dynamic, per-token) to INT8.
act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
weight_dtype=weight_dtype,
lowp_mode=lowp_mode,
act_quant_mode=act_quant_mode,
group_size=self.quant_config.group_size,
)
layer.ipex_output_size = layer.qweight.size(
1) * self.quant_config.pack_factor
layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \
IPEXWeightOnlyQuantizedLinear.from_weight(
layer.qweight,
layer.scales,
layer.qzeros,
layer.qweight.size(0),
layer.ipex_output_size,
qconfig=qconfig,
bias=bias,
group_size=self.quant_config.group_size,
quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["awq"] # type: ignore
)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
out = layer.ipex_qlinear(reshaped_x)
return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))
class XPUFp8LinearMethod(Fp8LinearMethod):
def __init__(self, quant_config: Fp8Config):
super().__init__(quant_config)
def process_weights_after_loading(self, layer: Module) -> None:
# If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
scale=None)
# Update the layer with the new values.
layer.weight = Parameter(qweight, requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.input_scale = None
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = layer.weight.data
weight_scale = layer.weight_scale.data
output = torch.ops.torch_ipex.fp8_gemm_w8a16(x, weight, True,
weight_scale, bias)
return output
class XPUFp8MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
super().__init__(layer.moe_config)
self.quant_config = quant_config
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
layer.intermediate_size_per_partition = intermediate_size_per_partition
layer.hidden_size = hidden_size
layer.num_experts = num_experts
layer.orig_dtype = params_dtype
layer.weight_block_size = None
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
2,
dtype=torch.float32),
requires_grad=False)
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
# INPUT_SCALES
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None:
if not self.quant_config.is_checkpoint_fp8_serialized:
fp8_dtype = current_platform.fp8_dtype()
w13_weight = torch.empty_like(layer.w13_weight.data,
dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
layer.local_num_experts,
dtype=torch.float32,
device=w13_weight.device),
requires_grad=False)
for expert in range(layer.local_num_experts):
w13_weight[expert, :, :], layer.w13_weight_scale[
expert] = ops.scaled_fp8_quant(
layer.w13_weight.data[expert, :, :])
w2_weight[expert, :, :], layer.w2_weight_scale[
expert] = ops.scaled_fp8_quant(
layer.w2_weight.data[expert, :, :])
layer.w13_weight = torch.nn.Parameter(w13_weight,
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight,
requires_grad=False)
import intel_extension_for_pytorch as ipex
layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
layer.w13_weight,
layer.w2_weight,
w1_scale_inv=layer.w13_weight_scale,
w2_scale_inv=layer.w2_weight_scale,
a1_scale_inv=layer.w13_input_scale,
a2_scale_inv=layer.w2_input_scale,
use_prepack=True,
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return layer.ipex_fusion(
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
custom_routing_function=custom_routing_function,
)

Some files were not shown because too many files have changed in this diff Show More