init
This commit is contained in:
157
vllm/model_executor/layers/quantization/__init__.py
Normal file
157
vllm/model_executor/layers/quantization/__init__.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Literal, get_args
|
||||
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
|
||||
QuantizationMethods = Literal[
|
||||
"awq",
|
||||
"deepspeedfp",
|
||||
"tpu_int8",
|
||||
"fp8",
|
||||
"ptpc_fp8",
|
||||
"fbgemm_fp8",
|
||||
"modelopt",
|
||||
"modelopt_fp4",
|
||||
"bitblas",
|
||||
"gguf",
|
||||
"gptq_marlin_24",
|
||||
"gptq_marlin",
|
||||
"gptq_bitblas",
|
||||
"awq_marlin",
|
||||
"gptq",
|
||||
"compressed-tensors",
|
||||
"bitsandbytes",
|
||||
"hqq",
|
||||
"experts_int8",
|
||||
"ipex",
|
||||
"quark",
|
||||
"moe_wna16",
|
||||
"torchao",
|
||||
"auto-round",
|
||||
"rtn",
|
||||
"inc",
|
||||
"mxfp4",
|
||||
"petit_nvfp4",
|
||||
]
|
||||
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
|
||||
|
||||
# The customized quantization methods which will be added to this dict.
|
||||
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {}
|
||||
|
||||
|
||||
def register_quantization_config(quantization: str):
|
||||
"""Register a customized vllm quantization config.
|
||||
|
||||
When a quantization method is not supported by vllm, you can register a customized
|
||||
quantization config to support it.
|
||||
|
||||
Args:
|
||||
quantization (str): The quantization method name.
|
||||
|
||||
Examples:
|
||||
>>> from vllm.model_executor.layers.quantization import register_quantization_config
|
||||
>>> from vllm.model_executor.layers.quantization import get_quantization_config
|
||||
>>> from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
>>>
|
||||
>>> @register_quantization_config("my_quant")
|
||||
... class MyQuantConfig(QuantizationConfig):
|
||||
... pass
|
||||
>>>
|
||||
>>> get_quantization_config("my_quant")
|
||||
<class 'MyQuantConfig'>
|
||||
""" # noqa: E501
|
||||
|
||||
def _wrapper(quant_config_cls):
|
||||
if quantization in QUANTIZATION_METHODS:
|
||||
raise ValueError(
|
||||
f"The quantization method `{quantization}` is already exists.")
|
||||
if not issubclass(quant_config_cls, QuantizationConfig):
|
||||
raise ValueError("The quantization config must be a subclass of "
|
||||
"`QuantizationConfig`.")
|
||||
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG[quantization] = quant_config_cls
|
||||
QUANTIZATION_METHODS.append(quantization)
|
||||
return quant_config_cls
|
||||
|
||||
return _wrapper
|
||||
|
||||
|
||||
def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
if quantization not in QUANTIZATION_METHODS:
|
||||
raise ValueError(f"Invalid quantization method: {quantization}")
|
||||
|
||||
# lazy import to avoid triggering `torch.compile` too early
|
||||
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
|
||||
|
||||
from .auto_round import AutoRoundConfig
|
||||
from .awq import AWQConfig
|
||||
from .awq_marlin import AWQMarlinConfig
|
||||
from .bitblas import BitBLASConfig
|
||||
from .bitsandbytes import BitsAndBytesConfig
|
||||
from .compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensorsConfig)
|
||||
from .deepspeedfp import DeepSpeedFPConfig
|
||||
from .experts_int8 import ExpertsInt8Config
|
||||
from .fbgemm_fp8 import FBGEMMFp8Config
|
||||
from .fp8 import Fp8Config
|
||||
from .gguf import GGUFConfig
|
||||
from .gptq import GPTQConfig
|
||||
from .gptq_bitblas import GPTQBitBLASConfig
|
||||
from .gptq_marlin import GPTQMarlinConfig
|
||||
from .gptq_marlin_24 import GPTQMarlin24Config
|
||||
from .hqq_marlin import HQQMarlinConfig
|
||||
from .inc import INCConfig
|
||||
from .ipex_quant import IPEXConfig
|
||||
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
|
||||
from .moe_wna16 import MoeWNA16Config
|
||||
from .mxfp4 import Mxfp4Config
|
||||
from .petit import PetitNvFp4Config
|
||||
from .ptpc_fp8 import PTPCFp8Config
|
||||
from .rtn import RTNConfig
|
||||
from .torchao import TorchAOConfig
|
||||
from .tpu_int8 import Int8TpuConfig
|
||||
|
||||
method_to_config: dict[str, type[QuantizationConfig]] = {
|
||||
"awq": AWQConfig,
|
||||
"deepspeedfp": DeepSpeedFPConfig,
|
||||
"tpu_int8": Int8TpuConfig,
|
||||
"fp8": Fp8Config,
|
||||
"fbgemm_fp8": FBGEMMFp8Config,
|
||||
"modelopt": ModelOptFp8Config,
|
||||
"modelopt_fp4": ModelOptNvFp4Config,
|
||||
"bitblas": BitBLASConfig,
|
||||
"gguf": GGUFConfig,
|
||||
"gptq_marlin_24": GPTQMarlin24Config,
|
||||
"gptq_marlin": GPTQMarlinConfig,
|
||||
"gptq_bitblas": GPTQBitBLASConfig,
|
||||
"awq_marlin": AWQMarlinConfig,
|
||||
"gptq": GPTQConfig,
|
||||
"compressed-tensors": CompressedTensorsConfig,
|
||||
"bitsandbytes": BitsAndBytesConfig,
|
||||
"ptpc_fp8": PTPCFp8Config,
|
||||
"hqq": HQQMarlinConfig,
|
||||
"experts_int8": ExpertsInt8Config,
|
||||
"ipex": IPEXConfig,
|
||||
"quark": QuarkConfig,
|
||||
"moe_wna16": MoeWNA16Config,
|
||||
"torchao": TorchAOConfig,
|
||||
"auto-round": AutoRoundConfig,
|
||||
"rtn": RTNConfig,
|
||||
"inc": INCConfig,
|
||||
"mxfp4": Mxfp4Config,
|
||||
"petit_nvfp4": PetitNvFp4Config,
|
||||
}
|
||||
# Update the `method_to_config` with customized quantization methods.
|
||||
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
|
||||
|
||||
return method_to_config[quantization]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"QuantizationConfig",
|
||||
"QuantizationMethods",
|
||||
"get_quantization_config",
|
||||
"QUANTIZATION_METHODS",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
388
vllm/model_executor/layers/quantization/auto_round.py
Normal file
388
vllm/model_executor/layers/quantization/auto_round.py
Normal file
@@ -0,0 +1,388 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from fractions import Fraction
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AutoRoundConfig(QuantizationConfig):
|
||||
"""Config class for AutoRound.
|
||||
Reference: https://arxiv.org/pdf/2309.05516
|
||||
"""
|
||||
|
||||
SUPPORTED_BITS = {2, 3, 4, 8}
|
||||
SUPPORTED_DTYPES = {"int"}
|
||||
SUPPORTED_FORMATS = {"auto_round:auto_gptq", "auto_round:auto_awq"}
|
||||
SUPPORTED_BACKENDS = {
|
||||
"auto",
|
||||
"gptq",
|
||||
"gptq:marlin",
|
||||
"awq",
|
||||
"awq:marlin",
|
||||
"marlin",
|
||||
"ipex",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
sym: bool = True,
|
||||
packing_format: str = "auto_round:auto_gptq",
|
||||
block_name_to_quantize: Optional[Union[str, list[str]]] = None,
|
||||
extra_config: Optional[dict[str, Any]] = None,
|
||||
data_type: str = "int",
|
||||
backend: str = "auto",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if weight_bits not in self.SUPPORTED_BITS:
|
||||
raise ValueError(f"Unsupported weight_bits: {weight_bits}, "
|
||||
f"currently only support {self.SUPPORTED_BITS}")
|
||||
if data_type not in self.SUPPORTED_DTYPES:
|
||||
raise ValueError(
|
||||
f"Unsupported data_type: {data_type},"
|
||||
f" currently only support {self.SUPPORTED_DTYPES}")
|
||||
if packing_format not in self.SUPPORTED_FORMATS:
|
||||
raise ValueError(
|
||||
f"Unsupported packing_format: {packing_format}, "
|
||||
f"currently only support {self.SUPPORTED_FORMATS}")
|
||||
if backend not in self.SUPPORTED_BACKENDS:
|
||||
raise ValueError(
|
||||
f"Unsupported backend: {backend}, "
|
||||
f"currently only support {self.SUPPORTED_BACKENDS}")
|
||||
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.sym = sym
|
||||
self.packing_format = packing_format
|
||||
self.block_name_to_quantize = (block_name_to_quantize.split(",") if
|
||||
isinstance(block_name_to_quantize, str)
|
||||
else block_name_to_quantize)
|
||||
self.extra_config = extra_config
|
||||
self.data_type = data_type
|
||||
self.backend = backend
|
||||
self.pack_factor = Fraction(32, weight_bits)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"AutoRoundConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, sym={self.sym})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "auto-round"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 60
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantization_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig":
|
||||
return cls(
|
||||
weight_bits=cls.get_from_keys(config, ["bits"]),
|
||||
group_size=cls.get_from_keys(config, ["group_size"]),
|
||||
sym=cls.get_from_keys(config, ["sym"]),
|
||||
packing_format=cls.get_from_keys_or(config, ["packing_format"],
|
||||
"auto_round:auto_gptq"),
|
||||
block_name_to_quantize=cls.get_from_keys_or(
|
||||
config, ["block_name_to_quantize", "to_quant_block_names"],
|
||||
None),
|
||||
extra_config=cls.get_from_keys_or(config, ["extra_config"], None),
|
||||
data_type=cls.get_from_keys_or(config, ["data_type"], "int"),
|
||||
backend=cls.get_from_keys_or(config, ["backend", "vllm_backend"],
|
||||
"auto"),
|
||||
)
|
||||
|
||||
def get_layer_config(self, layer, layer_name: str):
|
||||
|
||||
def get_config(name: str, quantized: bool = True):
|
||||
cfg = self.extra_config.get(name, {}) if self.extra_config else {}
|
||||
return (
|
||||
cfg.get("bits", self.weight_bits if quantized else 16),
|
||||
cfg.get("group_size", self.group_size if quantized else -1),
|
||||
cfg.get("sym", self.sym if quantized else True),
|
||||
)
|
||||
|
||||
# 1. Exact match from config
|
||||
if self.extra_config and layer_name in self.extra_config:
|
||||
return get_config(layer_name)
|
||||
|
||||
# 2. Determine whether layer should be quantized
|
||||
quantized = not isinstance(layer, ParallelLMHead)
|
||||
if self.block_name_to_quantize:
|
||||
quantized = any(
|
||||
layer_name.startswith(name)
|
||||
for name in self.block_name_to_quantize)
|
||||
|
||||
# 3. Handle fused MoE
|
||||
if self.extra_config and "fusedmoe" in layer.__class__.__name__.lower(
|
||||
):
|
||||
moe_configs = [
|
||||
get_config(name, quantized) for name in self.extra_config
|
||||
if name.startswith(layer_name)
|
||||
]
|
||||
if moe_configs:
|
||||
if len(set(moe_configs)) == 1:
|
||||
return moe_configs[0]
|
||||
raise ValueError(f"Fused MoE layer '{layer_name}' requires "
|
||||
f"consistent quant config for all sub-layers")
|
||||
|
||||
# 4. Handle fused QKV or other patterns
|
||||
if self.extra_config:
|
||||
for fusion_key, sub_keys in self.packed_modules_mapping.items():
|
||||
if fusion_key in layer_name and layer_name.count(
|
||||
fusion_key) == 1:
|
||||
sub_names = [
|
||||
layer_name.replace(fusion_key, sub_key)
|
||||
for sub_key in sub_keys
|
||||
]
|
||||
sub_configs = [
|
||||
get_config(name, quantized) for name in sub_names
|
||||
]
|
||||
if len(set(sub_configs)) == 1:
|
||||
return sub_configs[0]
|
||||
raise ValueError(
|
||||
f"Fused module '{layer_name}' requires "
|
||||
f"consistent quant config for {sub_names}")
|
||||
|
||||
# 5. Fallback
|
||||
return get_config(layer_name, quantized)
|
||||
|
||||
def check_quantized(self, weight_bits: int) -> bool:
|
||||
return weight_bits < 16
|
||||
|
||||
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
|
||||
if self.block_name_to_quantize is not None:
|
||||
self.block_name_to_quantize = hf_to_vllm_mapper.apply_list(
|
||||
self.block_name_to_quantize)
|
||||
if self.extra_config is not None:
|
||||
self.extra_config = hf_to_vllm_mapper.apply_dict(self.extra_config)
|
||||
|
||||
def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported, check_moe_marlin_supports_layer)
|
||||
|
||||
weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
|
||||
if not self.check_quantized(weight_bits):
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
return UnquantizedLinearMethod()
|
||||
else:
|
||||
return None
|
||||
|
||||
logger.debug(
|
||||
"[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
|
||||
prefix,
|
||||
layer.__class__.__name__,
|
||||
weight_bits,
|
||||
group_size,
|
||||
sym,
|
||||
)
|
||||
if backend == "auto" or "marlin" in backend:
|
||||
AWQ_TYPE_MAP = {
|
||||
4: scalar_types.uint4,
|
||||
8: scalar_types.uint8,
|
||||
}
|
||||
use_marlin = (weight_bits
|
||||
in AWQ_TYPE_MAP) and check_marlin_supported(
|
||||
AWQ_TYPE_MAP[weight_bits], group_size, not sym)
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
use_marlin = use_marlin and check_moe_marlin_supports_layer(
|
||||
layer, group_size)
|
||||
|
||||
else:
|
||||
use_marlin = False
|
||||
if use_marlin:
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import (
|
||||
AWQMarlinConfig, AWQMarlinLinearMethod, AWQMoEMethod)
|
||||
|
||||
quant_args_marlin = AWQMarlinConfig(
|
||||
weight_bits=weight_bits,
|
||||
group_size=group_size,
|
||||
zero_point=not sym,
|
||||
lm_head_quantized=False,
|
||||
full_config={},
|
||||
modules_to_not_convert=[],
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.quantization.awq import (
|
||||
AWQConfig, AWQLinearMethod)
|
||||
|
||||
quant_args = AWQConfig(
|
||||
weight_bits=weight_bits,
|
||||
group_size=group_size,
|
||||
zero_point=not sym,
|
||||
)
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
if use_marlin:
|
||||
return AWQMoEMethod(quant_args_marlin, layer.moe_config)
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import (
|
||||
MoeWNA16Config)
|
||||
|
||||
config = {
|
||||
"quant_method": "awq",
|
||||
"bits": weight_bits,
|
||||
"group_size": group_size,
|
||||
"zero_point": not sym,
|
||||
"lm_head": False,
|
||||
}
|
||||
return MoeWNA16Config.from_config(config).get_quant_method(
|
||||
layer, prefix)
|
||||
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
if use_marlin:
|
||||
return AWQMarlinLinearMethod(quant_args_marlin)
|
||||
else:
|
||||
return AWQLinearMethod(quant_args)
|
||||
return None
|
||||
|
||||
def apply_gptq_quant_layer(self,
|
||||
layer,
|
||||
prefix: str,
|
||||
backend: str = "auto"):
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported, check_moe_marlin_supports_layer)
|
||||
|
||||
weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
|
||||
if not self.check_quantized(weight_bits):
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
return UnquantizedLinearMethod()
|
||||
else:
|
||||
return None
|
||||
|
||||
logger.debug(
|
||||
"[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
|
||||
prefix,
|
||||
layer.__class__.__name__,
|
||||
weight_bits,
|
||||
group_size,
|
||||
sym,
|
||||
)
|
||||
if backend == "auto" or "marlin" in backend:
|
||||
GPTQ_TYPE_MAP = {
|
||||
(4, True): scalar_types.uint4b8,
|
||||
(8, True): scalar_types.uint8b128,
|
||||
}
|
||||
use_marlin = (weight_bits,
|
||||
sym) in GPTQ_TYPE_MAP and check_marlin_supported(
|
||||
GPTQ_TYPE_MAP[(weight_bits, sym)],
|
||||
group_size,
|
||||
has_zp=not sym)
|
||||
if isinstance(layer, FusedMoE):
|
||||
use_marlin = use_marlin and check_moe_marlin_supports_layer(
|
||||
layer, group_size)
|
||||
else:
|
||||
use_marlin = False
|
||||
if use_marlin:
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinConfig, GPTQMarlinLinearMethod, GPTQMarlinMoEMethod)
|
||||
|
||||
quant_args_marlin = GPTQMarlinConfig(
|
||||
weight_bits=weight_bits,
|
||||
group_size=group_size,
|
||||
is_sym=sym,
|
||||
lm_head_quantized=False,
|
||||
desc_act=False,
|
||||
dynamic={},
|
||||
full_config={},
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.quantization.gptq import (
|
||||
GPTQConfig, GPTQLinearMethod)
|
||||
|
||||
quant_args = GPTQConfig(
|
||||
weight_bits=weight_bits,
|
||||
group_size=group_size,
|
||||
lm_head_quantized=False,
|
||||
desc_act=False,
|
||||
dynamic={},
|
||||
)
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
if use_marlin:
|
||||
return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe_config)
|
||||
else:
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import (
|
||||
MoeWNA16Config)
|
||||
|
||||
config = {
|
||||
"quant_method": "gptq",
|
||||
"bits": weight_bits,
|
||||
"group_size": group_size,
|
||||
"sym": sym,
|
||||
"lm_head": False,
|
||||
}
|
||||
return MoeWNA16Config.from_config(config).get_quant_method(
|
||||
layer, prefix)
|
||||
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
if use_marlin:
|
||||
return GPTQMarlinLinearMethod(quant_args_marlin)
|
||||
else:
|
||||
return GPTQLinearMethod(quant_args)
|
||||
|
||||
return None
|
||||
|
||||
def apply_ipex_quant_layer(self, layer, prefix: str):
|
||||
weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
|
||||
if not self.check_quantized(weight_bits):
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
return UnquantizedLinearMethod()
|
||||
else:
|
||||
return None
|
||||
from vllm.model_executor.layers.quantization.ipex_quant import (
|
||||
IPEXAWQLinearMethod, IPEXConfig, IPEXGPTQLinearMethod)
|
||||
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
if "awq" in self.packing_format:
|
||||
config = IPEXConfig(method="awq",
|
||||
weight_bits=weight_bits,
|
||||
group_size=group_size)
|
||||
return IPEXAWQLinearMethod(config)
|
||||
elif "gptq" in self.packing_format:
|
||||
config = IPEXConfig(method="gptq",
|
||||
weight_bits=weight_bits,
|
||||
group_size=group_size)
|
||||
return IPEXGPTQLinearMethod(config)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"ipex backend only supports awq "
|
||||
f"and gtpq format,but got {self.packing_format}")
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module, prefix: str):
|
||||
if (current_platform.is_cpu() or current_platform.is_xpu()
|
||||
or self.backend == "ipex"):
|
||||
return self.apply_ipex_quant_layer(layer, prefix)
|
||||
if "gptq" in self.packing_format or "gptq" in self.backend:
|
||||
return self.apply_gptq_quant_layer(layer, prefix)
|
||||
if "awq" in self.packing_format or "awq" in self.backend:
|
||||
return self.apply_awq_quant_layer(layer, prefix)
|
||||
228
vllm/model_executor/layers/quantization/awq.py
Normal file
228
vllm/model_executor/layers/quantization/awq.py
Normal file
@@ -0,0 +1,228 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AWQConfig(QuantizationConfig):
|
||||
"""Config class for AWQ.
|
||||
|
||||
Reference: https://arxiv.org/abs/2306.00978
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
zero_point: bool,
|
||||
modules_to_not_convert: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.zero_point = zero_point
|
||||
self.modules_to_not_convert = modules_to_not_convert or []
|
||||
|
||||
if self.weight_bits != 4:
|
||||
raise ValueError(
|
||||
"Currently, only 4-bit weight quantization is supported for "
|
||||
f"AWQ, but got {self.weight_bits} bits.")
|
||||
self.pack_factor = 32 // self.weight_bits
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"AWQConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"zero_point={self.zero_point}, "
|
||||
f"modules_to_not_convert={self.modules_to_not_convert})")
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "awq"
|
||||
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# The AWQ kernel only supports Turing or newer GPUs.
|
||||
return 75
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
return [
|
||||
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
|
||||
# E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
|
||||
"quantize_config.json",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "AWQConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
|
||||
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
|
||||
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||
modules_to_not_convert = cls.get_from_keys_or(
|
||||
config, ["modules_to_not_convert"], None)
|
||||
return cls(weight_bits, group_size, zero_point, modules_to_not_convert)
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
|
||||
return UnquantizedLinearMethod()
|
||||
return AWQLinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
# Lazy import to avoid circular import.
|
||||
from .awq_marlin import AWQMarlinConfig, AWQMoEMethod
|
||||
from .moe_wna16 import MoeWNA16Config
|
||||
from .utils.marlin_utils import check_moe_marlin_supports_layer
|
||||
if not check_moe_marlin_supports_layer(layer, self.group_size):
|
||||
logger.warning_once(
|
||||
f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
|
||||
"Falling back to Moe WNA16 kernels.")
|
||||
config = {
|
||||
"quant_method": "awq",
|
||||
"bits": self.weight_bits,
|
||||
"group_size": self.group_size,
|
||||
"zero_point": self.zero_point,
|
||||
"lm_head": False,
|
||||
}
|
||||
return MoeWNA16Config.from_config(config).get_quant_method(
|
||||
layer, prefix)
|
||||
marlin_compatible_config_dict = {
|
||||
"quant_method": "awq",
|
||||
"bits": self.weight_bits,
|
||||
"group_size": self.group_size,
|
||||
"zero_point": self.zero_point,
|
||||
"lm_head": False,
|
||||
"modules_to_not_convert": self.modules_to_not_convert,
|
||||
}
|
||||
awq_marlin_config = AWQMarlinConfig.from_config(
|
||||
marlin_compatible_config_dict)
|
||||
return AWQMoEMethod(awq_marlin_config, layer.moe_config)
|
||||
return None
|
||||
|
||||
|
||||
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: list[str]):
|
||||
return any(module_name in prefix for module_name in modules_to_not_convert)
|
||||
|
||||
|
||||
class AWQLinearMethod(LinearMethodBase):
|
||||
"""Linear method for AWQ.
|
||||
|
||||
Args:
|
||||
quant_config: The AWQ quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: AWQConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
# Normalize group_size
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
|
||||
if input_size_per_partition % group_size != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if output_size_per_partition % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
num_groups = input_size_per_partition // group_size
|
||||
|
||||
qzeros = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
num_groups,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
scales = GroupQuantScaleParameter(data=torch.empty(
|
||||
num_groups,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.qweight = torch.nn.Parameter(layer.qweight.data,
|
||||
requires_grad=False)
|
||||
layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
|
||||
requires_grad=False)
|
||||
layer.scales = torch.nn.Parameter(layer.scales.data,
|
||||
requires_grad=False)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
qweight = layer.qweight
|
||||
scales = layer.scales
|
||||
qzeros = layer.qzeros
|
||||
pack_factor = self.quant_config.pack_factor
|
||||
out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
# num_tokens >= threshold
|
||||
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
|
||||
|
||||
if FP16_MATMUL_HEURISTIC_CONDITION:
|
||||
out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
|
||||
out = torch.matmul(reshaped_x, out)
|
||||
else:
|
||||
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
|
||||
pack_factor)
|
||||
if bias is not None:
|
||||
out.add_(bias)
|
||||
return out.reshape(out_shape)
|
||||
554
vllm/model_executor/layers/quantization/awq_marlin.py
Normal file
554
vllm/model_executor/layers/quantization/awq_marlin.py
Normal file
@@ -0,0 +1,554 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
||||
FusedMoEQuantConfig)
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
|
||||
UnquantizedFusedMoEMethod)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.awq import (AWQConfig,
|
||||
is_layer_skipped_awq)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
|
||||
check_marlin_supports_layer, check_moe_marlin_supports_layer,
|
||||
marlin_make_empty_g_idx, marlin_make_workspace_new,
|
||||
marlin_moe_permute_scales, marlin_permute_bias, marlin_permute_scales,
|
||||
moe_awq_to_marlin_zero_points, verify_marlin_supported,
|
||||
verify_marlin_supports_shape)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AWQMarlinConfig(QuantizationConfig):
|
||||
"""Config class for AWQ Marlin"""
|
||||
|
||||
# num_bits -> type
|
||||
TYPE_MAP = {
|
||||
4: scalar_types.uint4,
|
||||
8: scalar_types.uint8,
|
||||
}
|
||||
|
||||
def __init__(self, weight_bits: int, group_size: int, zero_point: bool,
|
||||
lm_head_quantized: bool,
|
||||
modules_to_not_convert: Optional[list[str]],
|
||||
full_config: dict[str, Any]) -> None:
|
||||
super().__init__()
|
||||
self.pack_factor = 32 // weight_bits # packed into int32
|
||||
self.group_size = group_size
|
||||
self.zero_point = zero_point
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.weight_bits = weight_bits
|
||||
self.modules_to_not_convert = modules_to_not_convert or []
|
||||
self.full_config = full_config
|
||||
|
||||
if self.weight_bits not in self.TYPE_MAP:
|
||||
raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
|
||||
f"Supported num_bits = {self.TYPE_MAP.keys()}")
|
||||
|
||||
self.quant_type = self.TYPE_MAP[self.weight_bits]
|
||||
|
||||
verify_marlin_supported(self.quant_type,
|
||||
group_size=self.group_size,
|
||||
has_zp=self.zero_point)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"AWQMarlinConfig(quant_type={self.quant_type}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"zero_point={self.zero_point}, "
|
||||
f"lm_head_quantized={self.lm_head_quantized}, "
|
||||
f"modules_to_not_convert={self.modules_to_not_convert})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "awq_marlin"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "AWQMarlinConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
modules_to_not_convert = cls.get_from_keys_or(
|
||||
config, ["modules_to_not_convert"], None)
|
||||
return cls(weight_bits, group_size, zero_point, lm_head_quantized,
|
||||
modules_to_not_convert, config)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "marlin"
|
||||
or user_quant == "awq_marlin")
|
||||
|
||||
if can_convert and is_valid_user_quant:
|
||||
msg = ("The model is convertible to {} during runtime."
|
||||
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
if can_convert and user_quant == "awq":
|
||||
logger.info("Detected that the model can run with awq_marlin"
|
||||
", however you specified quantization=awq explicitly,"
|
||||
" so forcing awq. Use quantization=awq_marlin for"
|
||||
" faster inference")
|
||||
return None
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if (isinstance(layer, LinearBase) or
|
||||
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
|
||||
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
|
||||
return UnquantizedLinearMethod()
|
||||
# Check if the layer is supported by AWQMarlin.
|
||||
if not check_marlin_supports_layer(layer, self.group_size):
|
||||
logger.warning_once(
|
||||
"Layer '%s' is not supported by AWQMarlin. Falling back to unoptimized AWQ kernels.", # noqa: E501
|
||||
prefix,
|
||||
)
|
||||
return AWQConfig.from_config(
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
return AWQMarlinLinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import (
|
||||
MoeWNA16Config)
|
||||
if is_layer_skipped_awq(
|
||||
prefix, getattr(self, "modules_to_not_convert", [])):
|
||||
return UnquantizedFusedMoEMethod(layer.moe_config)
|
||||
if not check_moe_marlin_supports_layer(layer, self.group_size):
|
||||
logger.warning_once(
|
||||
f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
|
||||
"Falling back to Moe WNA16 kernels.")
|
||||
return MoeWNA16Config.from_config(
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
return AWQMoEMethod(self, layer.moe_config)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_awq_marlin_compatible(cls, quant_config: dict[str, Any]):
|
||||
# Extract data from quant config.
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits")
|
||||
group_size = quant_config.get("group_size")
|
||||
zero_point = quant_config.get("zero_point")
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
|
||||
if quant_method != "awq":
|
||||
return False
|
||||
|
||||
# If we cannot find the info needed in the config, cannot convert.
|
||||
if (num_bits is None or group_size is None or zero_point is None):
|
||||
return False
|
||||
|
||||
if num_bits not in cls.TYPE_MAP:
|
||||
return False
|
||||
|
||||
return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
|
||||
group_size=group_size,
|
||||
has_zp=zero_point)
|
||||
|
||||
|
||||
class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
"""Linear method for AWQ Marlin.
|
||||
|
||||
Args:
|
||||
quant_config: The AWQ Marlin quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: AWQMarlinConfig) -> None:
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
del output_size
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
|
||||
# Normalize group_size
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
|
||||
verify_marlin_supports_shape(
|
||||
output_size_per_partition=output_size_per_partition,
|
||||
input_size_per_partition=input_size_per_partition,
|
||||
input_size=input_size,
|
||||
group_size=group_size)
|
||||
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
num_groups = input_size_per_partition // group_size
|
||||
|
||||
qzeros = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
num_groups,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
scales = GroupQuantScaleParameter(data=torch.empty(
|
||||
num_groups,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.num_groups = num_groups
|
||||
|
||||
# TODO: Update this docs
|
||||
# Checkpoints are serialized in AutoAWQ format, which is different from the
|
||||
# marlin format. This function is called after the weights are loaded.
|
||||
# Here, we handle the repacking
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
device = layer.qweight.device
|
||||
layer.qweight = torch.nn.Parameter(layer.qweight.data,
|
||||
requires_grad=False)
|
||||
layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
|
||||
requires_grad=False)
|
||||
layer.scales = torch.nn.Parameter(layer.scales.data,
|
||||
requires_grad=False)
|
||||
|
||||
# Allocate marlin workspace
|
||||
layer.workspace = marlin_make_workspace_new(device)
|
||||
|
||||
# Repack weights from AWQ format to marlin format.
|
||||
marlin_qweight = ops.awq_marlin_repack(
|
||||
layer.qweight,
|
||||
size_k=layer.input_size_per_partition,
|
||||
size_n=layer.output_size_per_partition,
|
||||
num_bits=self.quant_config.quant_type.size_bits)
|
||||
replace_parameter(layer, "qweight", marlin_qweight)
|
||||
|
||||
# Permute scales from AWQ format to marlin format.
|
||||
marlin_scales = marlin_permute_scales(
|
||||
layer.scales,
|
||||
size_k=layer.input_size_per_partition,
|
||||
size_n=layer.output_size_per_partition,
|
||||
group_size=self.quant_config.group_size)
|
||||
replace_parameter(layer, "scales", marlin_scales)
|
||||
|
||||
# Permute zero-points from AWQ format to marlin format.
|
||||
marlin_zp = awq_to_marlin_zero_points(
|
||||
layer.qzeros,
|
||||
size_k=layer.num_groups,
|
||||
size_n=layer.output_size_per_partition,
|
||||
num_bits=self.quant_config.quant_type.size_bits)
|
||||
replace_parameter(layer, "qzeros", marlin_zp)
|
||||
|
||||
# Not-used
|
||||
layer.g_idx = marlin_make_empty_g_idx(device)
|
||||
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
|
||||
|
||||
if hasattr(layer, "bias") and layer.bias is not None:
|
||||
layer.bias.data = marlin_permute_bias(layer.bias)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return apply_awq_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.qweight,
|
||||
weight_scale=layer.scales,
|
||||
weight_zp=layer.qzeros,
|
||||
g_idx=layer.g_idx,
|
||||
g_idx_sort_indices=layer.g_idx_sort_indices,
|
||||
workspace=layer.workspace,
|
||||
quant_type=self.quant_config.quant_type,
|
||||
output_size_per_partition=layer.output_size_per_partition,
|
||||
input_size_per_partition=layer.input_size_per_partition,
|
||||
bias=bias)
|
||||
|
||||
|
||||
class AWQMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: AWQMarlinConfig,
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
if self.quant_config.weight_bits != 4:
|
||||
raise ValueError("AWQMoEMethod only supports 4bit now.")
|
||||
self.quant_type = scalar_types.uint4
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
extra_weight_attrs.update({
|
||||
"is_transposed":
|
||||
True,
|
||||
"quant_method":
|
||||
FusedMoeWeightScaleSupported.GROUP.value,
|
||||
})
|
||||
|
||||
w13_qweight = Parameter(
|
||||
torch.empty(num_experts,
|
||||
hidden_size,
|
||||
2 * intermediate_size_per_partition //
|
||||
self.quant_config.pack_factor,
|
||||
dtype=torch.int32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_qweight", w13_qweight)
|
||||
set_weight_attrs(w13_qweight, extra_weight_attrs)
|
||||
|
||||
w2_qweight = Parameter(torch.empty(num_experts,
|
||||
intermediate_size_per_partition,
|
||||
hidden_size //
|
||||
self.quant_config.pack_factor,
|
||||
dtype=torch.int32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_qweight", w2_qweight)
|
||||
set_weight_attrs(w2_qweight, extra_weight_attrs)
|
||||
|
||||
num_groups_w13 = hidden_size // self.quant_config.group_size
|
||||
num_groups_w2 = (intermediate_size_per_partition //
|
||||
self.quant_config.group_size)
|
||||
|
||||
# WEIGHT_SCALES
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
w13_scales = Parameter(torch.empty(num_experts,
|
||||
num_groups_w13,
|
||||
intermediate_size_per_partition * 2,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_scales", w13_scales)
|
||||
set_weight_attrs(w13_scales, extra_weight_attrs)
|
||||
|
||||
w2_scales = Parameter(torch.empty(num_experts,
|
||||
num_groups_w2,
|
||||
hidden_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_scales", w2_scales)
|
||||
set_weight_attrs(w2_scales, extra_weight_attrs)
|
||||
|
||||
# WEIGHT_ZERO_POINT
|
||||
# Allocate 2 zero points for w1 and w3 respectively.
|
||||
w13_qzeros = Parameter(
|
||||
torch.empty(num_experts,
|
||||
num_groups_w13,
|
||||
2 * intermediate_size_per_partition //
|
||||
self.quant_config.pack_factor,
|
||||
dtype=torch.int32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_qzeros", w13_qzeros)
|
||||
set_weight_attrs(w13_qzeros, extra_weight_attrs)
|
||||
|
||||
w2_qzeros = Parameter(torch.empty(num_experts,
|
||||
num_groups_w2,
|
||||
hidden_size //
|
||||
self.quant_config.pack_factor,
|
||||
dtype=torch.int32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_qzeros", w2_qzeros)
|
||||
set_weight_attrs(w2_qzeros, extra_weight_attrs)
|
||||
|
||||
device = layer.w13_qweight.device
|
||||
layer.workspace = marlin_make_workspace_new(device, 4)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
num_experts = layer.w13_qweight.shape[0]
|
||||
device = layer.w13_qweight.device
|
||||
|
||||
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
marlin_w13_qweight = ops.awq_marlin_moe_repack(
|
||||
layer.w13_qweight,
|
||||
layer.w13_g_idx_sort_indices,
|
||||
size_k=layer.w13_qweight.shape[1],
|
||||
size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
|
||||
num_bits=self.quant_config.weight_bits,
|
||||
)
|
||||
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
|
||||
|
||||
marlin_w2_qweight = ops.awq_marlin_moe_repack(
|
||||
layer.w2_qweight,
|
||||
layer.w2_g_idx_sort_indices,
|
||||
size_k=layer.w2_qweight.shape[1],
|
||||
size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
|
||||
num_bits=self.quant_config.weight_bits,
|
||||
)
|
||||
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
|
||||
|
||||
# Why does this take the intermediate size for size_k?
|
||||
marlin_w13_scales = marlin_moe_permute_scales(
|
||||
s=layer.w13_scales,
|
||||
size_k=layer.intermediate_size_per_partition,
|
||||
size_n=layer.w13_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
)
|
||||
|
||||
replace_parameter(layer, "w13_scales", marlin_w13_scales)
|
||||
|
||||
marlin_w2_scales = marlin_moe_permute_scales(
|
||||
s=layer.w2_scales,
|
||||
size_k=layer.intermediate_size_per_partition,
|
||||
size_n=layer.w2_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
)
|
||||
replace_parameter(layer, "w2_scales", marlin_w2_scales)
|
||||
|
||||
marlin_w13_zp = moe_awq_to_marlin_zero_points(
|
||||
layer.w13_qzeros,
|
||||
size_k=layer.w13_qzeros.shape[1],
|
||||
size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
|
||||
num_bits=self.quant_config.weight_bits)
|
||||
replace_parameter(layer, "w13_qzeros", marlin_w13_zp)
|
||||
|
||||
marlin_w2_zp = moe_awq_to_marlin_zero_points(
|
||||
layer.w2_qzeros,
|
||||
size_k=layer.w2_qzeros.shape[1],
|
||||
size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
|
||||
num_bits=self.quant_config.weight_bits)
|
||||
replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
|
||||
|
||||
if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
|
||||
layer.w13_bias.data = marlin_permute_bias(layer.w13_bias)
|
||||
|
||||
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
|
||||
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
return None
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `AWQMoEMethod` yet.")
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
getattr(layer, "w13_bias", None),
|
||||
getattr(layer, "w2_bias", None),
|
||||
layer.w13_scales,
|
||||
layer.w2_scales,
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type_id=self.quant_type.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_zeros=layer.w13_qzeros,
|
||||
w2_zeros=layer.w2_qzeros,
|
||||
workspace=layer.workspace)
|
||||
320
vllm/model_executor/layers/quantization/awq_triton.py
Normal file
320
vllm/model_executor/layers/quantization/awq_triton.py
Normal file
@@ -0,0 +1,320 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
|
||||
|
||||
@triton.jit
|
||||
def awq_dequantize_kernel(
|
||||
qweight_ptr, # quantized matrix
|
||||
scales_ptr, # scales, per group
|
||||
zeros_ptr, # zeros, per group
|
||||
group_size, # Should always be one of the supported group sizes
|
||||
result_ptr, # Output matrix
|
||||
num_cols, # input num cols in qweight
|
||||
num_rows, # input num rows in qweight
|
||||
BLOCK_SIZE_X: tl.constexpr,
|
||||
BLOCK_SIZE_Y: tl.constexpr):
|
||||
# Set up the pids.
|
||||
pid_x = tl.program_id(axis=0)
|
||||
pid_y = tl.program_id(axis=1)
|
||||
|
||||
# Compute offsets and masks for qweight_ptr.
|
||||
offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
|
||||
offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
|
||||
offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]
|
||||
|
||||
masks_y = offsets_y < num_rows
|
||||
masks_x = offsets_x < num_cols
|
||||
|
||||
masks = masks_y[:, None] & masks_x[None, :]
|
||||
|
||||
# Compute offsets and masks for result output ptr.
|
||||
result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
|
||||
result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(
|
||||
0, BLOCK_SIZE_X * 8)
|
||||
result_offsets = (8 * num_cols * result_offsets_y[:, None] +
|
||||
result_offsets_x[None, :])
|
||||
|
||||
result_masks_y = result_offsets_y < num_rows
|
||||
result_masks_x = result_offsets_x < num_cols * 8
|
||||
result_masks = result_masks_y[:, None] & result_masks_x[None, :]
|
||||
|
||||
# Load the weights.
|
||||
iweights = tl.load(qweight_ptr + offsets, masks, 0.0)
|
||||
iweights = tl.interleave(iweights, iweights)
|
||||
iweights = tl.interleave(iweights, iweights)
|
||||
iweights = tl.interleave(iweights, iweights)
|
||||
|
||||
# Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
|
||||
# that will map given indices to the correct order.
|
||||
reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] +
|
||||
tl.arange(0, 4)[:, None]).reshape(8)
|
||||
|
||||
# Use this to compute a set of shifts that can be used to unpack and
|
||||
# reorder the values in iweights and zeros.
|
||||
shifts = reverse_awq_order_tensor * 4
|
||||
shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8))
|
||||
shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
|
||||
|
||||
# Unpack and reorder: shift out the correct 4-bit value and mask.
|
||||
iweights = (iweights >> shifts) & 0xF
|
||||
|
||||
# Compute zero offsets and masks.
|
||||
zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
|
||||
zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
|
||||
zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :]
|
||||
|
||||
zero_masks_y = zero_offsets_y < num_rows // group_size
|
||||
zero_masks_x = zero_offsets_x < num_cols
|
||||
zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :]
|
||||
|
||||
# Load the zeros.
|
||||
zeros = tl.load(zeros_ptr + zero_offsets, zero_masks, 0.0)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
|
||||
|
||||
# Unpack and reorder: shift out the correct 4-bit value and mask.
|
||||
zeros = (zeros >> shifts) & 0xF
|
||||
|
||||
# Compute scale offsets and masks.
|
||||
scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
|
||||
scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 +
|
||||
tl.arange(0, BLOCK_SIZE_X * 8))
|
||||
scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] +
|
||||
scale_offsets_x[None, :])
|
||||
scale_masks_y = scale_offsets_y < num_rows // group_size
|
||||
scale_masks_x = scale_offsets_x < num_cols * 8
|
||||
scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :]
|
||||
|
||||
# Load the scales.
|
||||
scales = tl.load(scales_ptr + scale_offsets, scale_masks, 0.0)
|
||||
scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
|
||||
|
||||
# Dequantize.
|
||||
iweights = (iweights - zeros) * scales
|
||||
iweights = iweights.to(result_ptr.type.element_ty)
|
||||
|
||||
# Finally, store.
|
||||
tl.store(result_ptr + result_offsets, iweights, result_masks)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
|
||||
group_size, BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
pid_z = tl.program_id(1)
|
||||
|
||||
# NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
|
||||
# num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
|
||||
pid_m = pid // num_pid_n
|
||||
pid_n = pid % num_pid_n
|
||||
|
||||
accumulator_dtype = c_ptr.type.element_ty
|
||||
|
||||
# NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
|
||||
# accumulator = tl.arange(0, BLOCK_SIZE_N)
|
||||
# accumulator = tl.broadcast_to(accumulator[None, :],
|
||||
# (BLOCK_SIZE_M, BLOCK_SIZE_N))
|
||||
# accumulator = accumulator & 0x0
|
||||
# accumulator = accumulator.to(accumulator_dtype)
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),
|
||||
dtype=accumulator_dtype)
|
||||
|
||||
# Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
|
||||
# that will map given indices to the correct order.
|
||||
reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] +
|
||||
tl.arange(0, 4)[:, None]).reshape(8)
|
||||
|
||||
# Create the necessary shifts to use to unpack.
|
||||
shifts = reverse_awq_order_tensor * 4
|
||||
shifts = tl.broadcast_to(shifts[None, :],
|
||||
(BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8))
|
||||
shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N))
|
||||
|
||||
# Offsets and masks.
|
||||
offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
masks_am = offsets_am < M
|
||||
|
||||
offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
|
||||
masks_bn = offsets_bn < N // 8
|
||||
|
||||
offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
|
||||
masks_zn = offsets_zn < N // 8
|
||||
|
||||
offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
masks_sn = offsets_sn < N
|
||||
|
||||
offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
||||
offsets_a = K * offsets_am[:, None] + offsets_k[None, :]
|
||||
offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :]
|
||||
|
||||
a_ptrs = a_ptr + offsets_a
|
||||
b_ptrs = b_ptr + offsets_b
|
||||
|
||||
# NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv
|
||||
# block_offset = BLOCK_SIZE_K * SPLIT_K
|
||||
# for k in range(0, (K + block_offset - 1) // (block_offset)):
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
|
||||
masks_k = offsets_k < K
|
||||
masks_a = masks_am[:, None] & masks_k[None, :]
|
||||
a = tl.load(a_ptrs, mask=masks_a, other=0.0)
|
||||
|
||||
masks_b = masks_k[:, None] & masks_bn[None, :]
|
||||
b = tl.load(b_ptrs, mask=masks_b, other=0.0)
|
||||
b = tl.interleave(b, b)
|
||||
b = tl.interleave(b, b)
|
||||
b = tl.interleave(b, b)
|
||||
|
||||
# Dequantize b.
|
||||
offsets_szk = (
|
||||
(BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size +
|
||||
tl.arange(0, 1))
|
||||
offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]
|
||||
masks_zk = offsets_szk < K // group_size
|
||||
masks_z = masks_zk[:, None] & masks_zn[None, :]
|
||||
zeros_ptrs = zeros_ptr + offsets_z
|
||||
zeros = tl.load(zeros_ptrs, mask=masks_z, other=0.0)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N))
|
||||
|
||||
offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]
|
||||
masks_sk = offsets_szk < K // group_size
|
||||
masks_s = masks_sk[:, None] & masks_sn[None, :]
|
||||
scales_ptrs = scales_ptr + offsets_s
|
||||
scales = tl.load(scales_ptrs, mask=masks_s, other=0.0)
|
||||
scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))
|
||||
|
||||
b = (b >> shifts) & 0xF
|
||||
zeros = (zeros >> shifts) & 0xF
|
||||
b = (b - zeros) * scales
|
||||
b = b.to(c_ptr.type.element_ty)
|
||||
|
||||
# Accumulate results.
|
||||
accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
|
||||
|
||||
offsets_k += BLOCK_SIZE_K * SPLIT_K
|
||||
a_ptrs += BLOCK_SIZE_K * SPLIT_K
|
||||
b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8)
|
||||
|
||||
c = accumulator.to(c_ptr.type.element_ty)
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = c_ptr + pid_z * N * M + N * offs_cm[:, None] + offs_cn[None, :]
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
|
||||
# qweights - [K , M // 8], int32
|
||||
# scales - [K // G, M ], float16
|
||||
# zeros - [K // G, M // 8], int32
|
||||
def awq_dequantize_triton(qweight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
zeros: torch.Tensor,
|
||||
block_size_x: int = 32,
|
||||
block_size_y: int = 32) -> torch.Tensor:
|
||||
K = qweight.shape[0]
|
||||
M = scales.shape[1]
|
||||
group_size = qweight.shape[0] // scales.shape[0]
|
||||
|
||||
assert K > 0 and M > 0
|
||||
assert scales.shape[0] == K // group_size and scales.shape[1] == M
|
||||
assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8
|
||||
assert group_size <= K
|
||||
assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K
|
||||
|
||||
# Result tensor:
|
||||
# number of rows = same as input tensor
|
||||
# number of cols = 8 x input tensor num cols
|
||||
result = torch.empty(qweight.shape[0],
|
||||
qweight.shape[1] * 8,
|
||||
device=qweight.device,
|
||||
dtype=scales.dtype)
|
||||
|
||||
Y = qweight.shape[0] # num rows
|
||||
X = qweight.shape[1] # num cols
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(X, META['BLOCK_SIZE_X']),
|
||||
triton.cdiv(Y, META['BLOCK_SIZE_Y']),
|
||||
)
|
||||
awq_dequantize_kernel[grid](qweight,
|
||||
scales,
|
||||
zeros,
|
||||
group_size,
|
||||
result,
|
||||
X,
|
||||
Y,
|
||||
BLOCK_SIZE_X=block_size_x,
|
||||
BLOCK_SIZE_Y=block_size_y)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# input - [M, K]
|
||||
# qweight - [K, N // 8]
|
||||
# qzeros - [K // G, N // 8]
|
||||
# scales - [K // G, N]
|
||||
# split_k_iters - parallelism along K-dimension, int, power of 2.
|
||||
def awq_gemm_triton(input: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
qzeros: torch.Tensor,
|
||||
split_k_iters: int,
|
||||
block_size_m: int = 32,
|
||||
block_size_n: int = 32,
|
||||
block_size_k: int = 32) -> torch.Tensor:
|
||||
M, K = input.shape
|
||||
N = qweight.shape[1] * 8
|
||||
group_size = qweight.shape[0] // qzeros.shape[0]
|
||||
|
||||
assert N > 0 and K > 0 and M > 0
|
||||
assert qweight.shape[0] == K and qweight.shape[1] == N // 8
|
||||
assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8
|
||||
assert scales.shape[0] == K // group_size and scales.shape[1] == N
|
||||
assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0
|
||||
assert split_k_iters <= 32
|
||||
assert group_size <= K
|
||||
assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
|
||||
N, META['BLOCK_SIZE_N']),
|
||||
split_k_iters,
|
||||
)
|
||||
|
||||
result = torch.zeros((split_k_iters, M, N),
|
||||
dtype=scales.dtype,
|
||||
device=input.device)
|
||||
|
||||
# A = input, B = qweight, C = result
|
||||
# A = M x K, B = K x N, C = M x N
|
||||
awq_gemm_kernel[grid](input,
|
||||
qweight,
|
||||
result,
|
||||
qzeros,
|
||||
scales,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
group_size,
|
||||
BLOCK_SIZE_M=block_size_m,
|
||||
BLOCK_SIZE_N=block_size_n,
|
||||
BLOCK_SIZE_K=block_size_k,
|
||||
SPLIT_K=split_k_iters)
|
||||
|
||||
result = result.sum(0)
|
||||
|
||||
return result
|
||||
170
vllm/model_executor/layers/quantization/base_config.py
Normal file
170
vllm/model_executor/layers/quantization/base_config.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
else:
|
||||
QuantizationMethods = str
|
||||
|
||||
|
||||
class QuantizeMethodBase(ABC):
|
||||
"""Base class for different quantized methods."""
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(self, layer: torch.nn.Module, *weight_args,
|
||||
**extra_weight_attrs):
|
||||
"""Create weights for a layer.
|
||||
|
||||
The weights will be set as attributes of the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
|
||||
"""Apply the weights in layer to the input tensor.
|
||||
|
||||
Expects create_weights to have been called before on the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
# Not required functions
|
||||
def embedding(self, layer: torch.nn.Module, *args,
|
||||
**kwargs) -> torch.Tensor:
|
||||
"""Gather embeddings in the layer based on indices in the input tensor.
|
||||
|
||||
Expects create_weights to have been called before on the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
def process_weights_after_loading(self, layer: nn.Module) -> None:
|
||||
"""Process the weight after loading.
|
||||
|
||||
This can be used for example, to transpose weights for computation.
|
||||
"""
|
||||
return
|
||||
|
||||
|
||||
def method_has_implemented_embedding(
|
||||
method_class: type[QuantizeMethodBase]) -> bool:
|
||||
"""
|
||||
Not all quant methods have embedding implemented, so we need to check that
|
||||
it exists for our given method. We check this by making sure the function
|
||||
has been changed from the base implementation.
|
||||
"""
|
||||
base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding",
|
||||
None)
|
||||
class_embedding = inspect.getattr_static(method_class, "embedding", None)
|
||||
|
||||
return (class_embedding is not None
|
||||
and class_embedding is not base_embedding)
|
||||
|
||||
|
||||
class QuantizationConfig(ABC):
|
||||
"""Base class for quantization configs."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# mapping is updated by models as they initialize
|
||||
self.packed_modules_mapping: dict[str, list[str]] = dict()
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
"""Name of the quantization method."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
"""List of supported activation dtypes."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
"""Minimum GPU capability to support the quantization method.
|
||||
|
||||
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
|
||||
This requirement is due to the custom CUDA kernels used by the
|
||||
quantization method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
"""List of filenames to search for in the model directory."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "QuantizationConfig":
|
||||
"""Create a config class from the model's quantization config."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
"""
|
||||
Detects if this quantization method can support a given checkpoint
|
||||
format by overriding the user specified quantization method --
|
||||
this method should only be overwritten by subclasses in exceptional
|
||||
circumstances
|
||||
"""
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_from_keys(config: dict[str, Any], keys: list[str]) -> Any:
|
||||
"""Get a value from the model's quantization config."""
|
||||
for key in keys:
|
||||
if key in config:
|
||||
return config[key]
|
||||
raise ValueError(f"Cannot find any of {keys} in the model's "
|
||||
"quantization config.")
|
||||
|
||||
@staticmethod
|
||||
def get_from_keys_or(config: dict[str, Any], keys: list[str],
|
||||
default: Any) -> Any:
|
||||
"""Get an optional value from the model's quantization config."""
|
||||
try:
|
||||
return QuantizationConfig.get_from_keys(config, keys)
|
||||
except ValueError:
|
||||
return default
|
||||
|
||||
@abstractmethod
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional[QuantizeMethodBase]:
|
||||
"""Get the quantize method to use for the quantized layer.
|
||||
|
||||
Args:
|
||||
layer: The layer for the quant method.
|
||||
prefix: The full name of the layer in the state dict
|
||||
Returns:
|
||||
The quantize method. None if the given layer doesn't support quant
|
||||
method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_cache_scale(self, name: str) -> Optional[str]:
|
||||
return None
|
||||
|
||||
def apply_vllm_mapper( # noqa: B027
|
||||
self, hf_to_vllm_mapper: "WeightsMapper"):
|
||||
"""
|
||||
Interface for models to update module names referenced in
|
||||
quantization configs in order to reflect the vllm model structure
|
||||
|
||||
:param hf_to_vllm_mapper: maps from hf model structure (the assumed
|
||||
structure of the qconfig) to vllm model structure
|
||||
"""
|
||||
# TODO (@kylesayrs): add implementations for all subclasses
|
||||
pass
|
||||
|
||||
def maybe_update_config(self, model_name: str): # noqa: B027
|
||||
"""
|
||||
Interface to update values after config initialization.
|
||||
"""
|
||||
pass
|
||||
464
vllm/model_executor/layers/quantization/bitblas.py
Normal file
464
vllm/model_executor/layers/quantization/bitblas.py
Normal file
@@ -0,0 +1,464 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
|
||||
BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_NUM_BITS,
|
||||
BITBLAS_SUPPORTED_SYM, MINIMUM_BITBLAS_VERSION)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BitBLASConfig(QuantizationConfig):
|
||||
"""Config class for BitBLAS.
|
||||
|
||||
Reference: https://github.com/Microsoft/BitBLAS
|
||||
"""
|
||||
TORCH_DTYPE = torch.float16
|
||||
STORAGE_DTYPE = "int8" # assume int8 storage
|
||||
TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE)
|
||||
# "original" or "rescale" or "quantized",
|
||||
# gptq_with_bitblas prefer "quantized implementation"
|
||||
ZEROS_MODE = "quantized"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: Optional[int],
|
||||
desc_act: Optional[bool],
|
||||
is_sym: Optional[bool],
|
||||
quant_method: Optional[str],
|
||||
lm_head_quantized: bool,
|
||||
) -> None:
|
||||
try:
|
||||
import bitblas
|
||||
if version.parse(bitblas.__version__) < version.parse(
|
||||
MINIMUM_BITBLAS_VERSION):
|
||||
raise ImportError(
|
||||
"bitblas version is wrong. Please "
|
||||
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
|
||||
except ImportError as e:
|
||||
bitblas_import_exception = e
|
||||
raise ValueError(
|
||||
"Trying to use the bitblas backend, but could not import"
|
||||
f"with the following error: {bitblas_import_exception}. "
|
||||
"Please install bitblas through the following command: "
|
||||
f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
|
||||
) from bitblas_import_exception
|
||||
|
||||
if desc_act and group_size == -1:
|
||||
# In this case, act_order == True is the same as act_order == False
|
||||
# (since we have only one group per output channel)
|
||||
desc_act = False
|
||||
|
||||
super().__init__()
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.desc_act = desc_act
|
||||
self.is_sym = is_sym
|
||||
self.quant_method = quant_method
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
|
||||
# Verify
|
||||
if self.weight_bits not in BITBLAS_SUPPORTED_NUM_BITS:
|
||||
raise ValueError(
|
||||
f"BitBLAS does not support weight_bits = {self.weight_bits}. "
|
||||
f"Only weight_bits = {BITBLAS_SUPPORTED_NUM_BITS} "
|
||||
"are supported.")
|
||||
|
||||
if self.is_sym not in BITBLAS_SUPPORTED_SYM:
|
||||
raise ValueError(
|
||||
f"BitBLAS does not support is_sym = {self.is_sym}. "
|
||||
f"Only sym = {BITBLAS_SUPPORTED_SYM} are supported.")
|
||||
|
||||
storage_dtype = self.STORAGE_DTYPE
|
||||
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
|
||||
|
||||
self.storage_dtype = storage_dtype
|
||||
self.storage_torch_dtype = self.TORCH_STORAGE_DTYPE
|
||||
# 4 Bits packed into 32 bit datatype.
|
||||
self.pack_factor = storage_nbit // weight_bits
|
||||
self.nbits = weight_bits
|
||||
|
||||
# Zeros type for the quantized weights.
|
||||
self.zeros_mode = self.ZEROS_MODE
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"BitBLASConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act}, "
|
||||
f"is_sym={self.is_sym}, "
|
||||
f"quant_method={self.quant_method})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "bitblas"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@staticmethod
|
||||
def get_from_keys(config: dict[str, Any],
|
||||
keys: list[str],
|
||||
default: Any = None) -> Any:
|
||||
"""Get a value from the model's quantization config."""
|
||||
for key in keys:
|
||||
if key in config:
|
||||
return config[key]
|
||||
return default
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "BitBLASConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"], -1)
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"], False)
|
||||
is_sym = cls.get_from_keys(config, ["sym"], False)
|
||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
return cls(weight_bits, group_size, desc_act, is_sym, quant_method,
|
||||
lm_head_quantized)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
# compat: autogptq >=0.8.0 use checkpoint_format: str
|
||||
# compat: autogptq <=0.7.1 is_bitblas_format: bool
|
||||
is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas"
|
||||
or hf_quant_cfg.get("is_bitblas_format", False))
|
||||
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
|
||||
or user_quant == "bitblas")
|
||||
|
||||
if is_bitblas_format and is_valid_user_quant:
|
||||
msg = ("The model is serialized in {} format. Using {} kernel.".
|
||||
format(cls.get_name(), cls.get_name()))
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
return None
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["BitBLASLinearMethod"]:
|
||||
if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
|
||||
and self.lm_head_quantized):
|
||||
return BitBLASLinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class BitBLASLinearMethod(LinearMethodBase):
|
||||
"""Linear method for BitBLAS.
|
||||
|
||||
Args:
|
||||
quant_config: The BitBLAS quantization config.
|
||||
"""
|
||||
# USE BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS
|
||||
# Instead of BITBLAS_OPTIMIZE_FEATURES
|
||||
# If you want to high contiguous batching
|
||||
# performance
|
||||
OPT_FEATURES = BITBLAS_OPTIMIZE_FEATURES
|
||||
ENABLE_TUNING = True
|
||||
BITBLAS_DTYPES = {
|
||||
torch.float32: "float32",
|
||||
torch.float16: "float16",
|
||||
torch.bfloat16: "bfloat16",
|
||||
torch.half: "float16",
|
||||
torch.int8: "int8",
|
||||
}
|
||||
|
||||
def __init__(self, quant_config: BitBLASConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights_gptq(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
"""Creates quantized weights for use in linear operations.
|
||||
|
||||
The function initializes and returns a dictionary containing quantized
|
||||
weights, scales, and zeros
|
||||
for performing quantized matrix multiplication operations.
|
||||
|
||||
Args:
|
||||
input_size_per_partition: The size of the input partition.
|
||||
output_partition_sizes: List of output partition sizes.
|
||||
input_size: The total size of the input (unused).
|
||||
output_size: The total size of the output (unused).
|
||||
params_dtype:
|
||||
The data type of the parameters (expected to be torch.float16).
|
||||
|
||||
Returns:
|
||||
A dictionary containing the quantized weights ('qweight'),
|
||||
scales ('scales'), and zeros ('zeros').
|
||||
|
||||
Raises:
|
||||
ValueError: If `params_dtype` is not `torch.float16` or if the input
|
||||
size per partition is not divisible by the group size
|
||||
in `quant_config`.
|
||||
"""
|
||||
del input_size, output_size # Unused arguments.
|
||||
weight_loader = extra_weight_attrs["weight_loader"]
|
||||
|
||||
if params_dtype not in self.quant_config.get_supported_act_dtypes():
|
||||
raise ValueError("Parameter data type must be torch.float16, "
|
||||
f"but got {params_dtype}")
|
||||
group_size = self.quant_config.group_size
|
||||
if group_size is None:
|
||||
group_size = -1
|
||||
# Validate output_size_per_partition
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if (group_size != -1 and input_size_per_partition % group_size != 0):
|
||||
raise ValueError(
|
||||
f"Input size per partition ({input_size_per_partition}) must "
|
||||
f"be divisible by group size ({group_size}).")
|
||||
|
||||
# Initialize or retrieve the BitBLAS matrix multiplication operator.
|
||||
self._configure_bitblas_matmul(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition,
|
||||
params_dtype=params_dtype,
|
||||
enable_tuning=self.ENABLE_TUNING,
|
||||
bias=False,
|
||||
layout="nt",
|
||||
bits=self.quant_config.weight_bits,
|
||||
)
|
||||
|
||||
# Initialize quantized weights with dimensions
|
||||
# Quantized 4Bit weights packed.
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
self.bitblas_matmul.retrieve_weight_shape(),
|
||||
device="cuda",
|
||||
dtype=self.quant_config.storage_torch_dtype,
|
||||
requires_grad=False,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
bitblas_tile_size=(self.bitblas_matmul.retrieve_weight_shape()[-2]
|
||||
if self.bitblas_matmul.propagate_b else None),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
# Compute the number of input groups for channel-wise quantization.
|
||||
input_groups = (1 if group_size == -1 else input_size_per_partition //
|
||||
group_size)
|
||||
|
||||
# Initialize scales and zeros for the quantized weights.
|
||||
weight_scale_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
output_size_per_partition,
|
||||
input_groups,
|
||||
device="cuda",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
if input_groups == 1:
|
||||
scales = ChannelQuantScaleParameter(output_dim=0,
|
||||
**weight_scale_args)
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(output_dim=0,
|
||||
input_dim=1,
|
||||
**weight_scale_args)
|
||||
|
||||
if self.quant_config.zeros_mode == "quantized":
|
||||
zeros = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_groups,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=self.quant_config.storage_torch_dtype,
|
||||
requires_grad=False,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
else:
|
||||
zeros = BasevLLMParameter(
|
||||
torch.empty(output_size_per_partition,
|
||||
input_groups,
|
||||
device="cuda",
|
||||
dtype=params_dtype),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
# Set attributes to indicate how scales and zeros are applied.
|
||||
set_weight_attrs(zeros, {
|
||||
"input_dim": None if input_groups == 1 else 1,
|
||||
"output_dim": 0,
|
||||
})
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("scales", scales)
|
||||
layer.register_parameter("zeros", zeros)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
if self.quant_config.quant_method == "gptq":
|
||||
return self.create_weights_gptq(layer, input_size_per_partition,
|
||||
output_partition_sizes, input_size,
|
||||
output_size, params_dtype,
|
||||
**extra_weight_attrs)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported quant_method {self.quant_config.quant_method}")
|
||||
|
||||
def _configure_bitblas_matmul(
|
||||
self,
|
||||
infeatures,
|
||||
outfeatures,
|
||||
params_dtype,
|
||||
enable_tuning,
|
||||
bias,
|
||||
layout,
|
||||
bits,
|
||||
out_dtype="float16",
|
||||
):
|
||||
from bitblas import MatmulConfig
|
||||
bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]
|
||||
|
||||
with_scaling = False
|
||||
with_zeros = False
|
||||
group_size = self.quant_config.group_size
|
||||
zeros_mode = self.quant_config.zeros_mode
|
||||
if self.quant_config.quant_method == "gptq":
|
||||
with_scaling = True
|
||||
with_zeros = True
|
||||
W_dtype = f"uint{bits}"
|
||||
if self.quant_config.is_sym:
|
||||
with_zeros = False
|
||||
W_dtype = f"int{bits}"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported quant_method {self.quant_config.quant_method}")
|
||||
|
||||
matmul_config = MatmulConfig(
|
||||
N=outfeatures,
|
||||
K=infeatures,
|
||||
A_dtype=bitblas_dtype,
|
||||
W_dtype=W_dtype,
|
||||
out_dtype=out_dtype,
|
||||
accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype,
|
||||
storage_dtype=self.quant_config.STORAGE_DTYPE,
|
||||
with_scaling=with_scaling,
|
||||
with_zeros=with_zeros,
|
||||
group_size=group_size,
|
||||
with_bias=bias,
|
||||
layout=layout,
|
||||
zeros_mode=zeros_mode,
|
||||
)
|
||||
self.bitblas_matmul = self._get_or_create_bitblas_operator(
|
||||
matmul_config, enable_tuning)
|
||||
|
||||
def _get_or_create_bitblas_operator(self, config, enable_tuning):
|
||||
from bitblas import Matmul, auto_detect_nvidia_target
|
||||
from bitblas.cache import get_database_path, global_operator_cache
|
||||
BITBLAS_DATABASE_PATH = get_database_path()
|
||||
BITBLAS_TARGET = auto_detect_nvidia_target()
|
||||
if global_operator_cache.size() == 0:
|
||||
global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
|
||||
BITBLAS_TARGET)
|
||||
|
||||
bitblas_matmul = global_operator_cache.get(config)
|
||||
if bitblas_matmul is None:
|
||||
bitblas_matmul = Matmul(config,
|
||||
target=BITBLAS_TARGET,
|
||||
enable_tuning=False)
|
||||
if enable_tuning:
|
||||
TUNING_MESSAGE = (f"BitBLAS Operator {config} is tuning ...")
|
||||
logger.info(TUNING_MESSAGE)
|
||||
bitblas_matmul.hardware_aware_finetune(topk=20)
|
||||
global_operator_cache.add(config, bitblas_matmul)
|
||||
global_operator_cache.save_into_database(
|
||||
BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
|
||||
TUNED_MESSAGE = (
|
||||
f"BitBLAS Operator {config} tuned and saved to database.")
|
||||
logger.info(TUNED_MESSAGE)
|
||||
else:
|
||||
_message = f"BitBLAS Operator {config} created."
|
||||
logger.info(_message)
|
||||
else:
|
||||
_message = (
|
||||
f"BitBLAS Operator {config} found in global_operator_cache.")
|
||||
logger.info(_message)
|
||||
return bitblas_matmul
|
||||
|
||||
def apply_gptq(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
qweight = layer.qweight
|
||||
scales = layer.scales
|
||||
qzeros = layer.zeros
|
||||
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
if self.quant_config.is_sym:
|
||||
output_2d = self.bitblas_matmul(x_2d, qweight, scales)
|
||||
else:
|
||||
output_2d = self.bitblas_matmul(x_2d, qweight, scales, qzeros)
|
||||
|
||||
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output
|
||||
|
||||
def apply(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> torch.Tensor:
|
||||
if self.quant_config.quant_method == "gptq":
|
||||
return self.apply_gptq(*args, **kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported quant_method {self.quant_config.quant_method}")
|
||||
627
vllm/model_executor/layers/quantization/bitsandbytes.py
Normal file
627
vllm/model_executor/layers/quantization/bitsandbytes.py
Normal file
@@ -0,0 +1,627 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
||||
FusedMoEQuantConfig)
|
||||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
|
||||
class BitsAndBytesConfig(QuantizationConfig):
|
||||
"""Config class for BitsAndBytes Quantization.
|
||||
|
||||
Reference: https://arxiv.org/abs/2305.14314
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
load_in_8bit: bool = False,
|
||||
load_in_4bit: bool = True,
|
||||
bnb_4bit_compute_dtype: str = "float32",
|
||||
bnb_4bit_quant_storage: str = "uint8",
|
||||
bnb_4bit_quant_type: str = "fp4",
|
||||
bnb_4bit_use_double_quant: bool = False,
|
||||
llm_int8_enable_fp32_cpu_offload: bool = False,
|
||||
llm_int8_has_fp16_weight: bool = False,
|
||||
llm_int8_skip_modules: Optional[list[str]] = None,
|
||||
llm_int8_threshold: float = 6.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.load_in_8bit = load_in_8bit
|
||||
self.load_in_4bit = load_in_4bit
|
||||
self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
|
||||
self.bnb_4bit_quant_storage = bnb_4bit_quant_storage
|
||||
self.bnb_4bit_quant_type = bnb_4bit_quant_type
|
||||
self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
|
||||
self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
|
||||
self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
|
||||
self.llm_int8_skip_modules = llm_int8_skip_modules or []
|
||||
self.llm_int8_threshold = llm_int8_threshold
|
||||
|
||||
if self.bnb_4bit_quant_storage not in ["uint8"]:
|
||||
raise ValueError("Unsupported bnb_4bit_quant_storage: "
|
||||
f"{self.bnb_4bit_quant_storage}")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, "
|
||||
f"load_in_4bit={self.load_in_4bit}, "
|
||||
f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
|
||||
f"bnb_4bit_quant_storage={self.bnb_4bit_quant_storage}, "
|
||||
f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
|
||||
f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
|
||||
|
||||
@classmethod
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "bitsandbytes"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
return [torch.float32, torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "BitsAndBytesConfig":
|
||||
|
||||
def get_safe_value(config, keys, default_value=None):
|
||||
try:
|
||||
value = cls.get_from_keys(config, keys)
|
||||
return value if value is not None else default_value
|
||||
except ValueError:
|
||||
return default_value
|
||||
|
||||
load_in_8bit = get_safe_value(config, ["load_in_8bit"],
|
||||
default_value=False)
|
||||
load_in_4bit = get_safe_value(config, ["load_in_4bit"],
|
||||
default_value=True)
|
||||
bnb_4bit_compute_dtype = get_safe_value(config,
|
||||
["bnb_4bit_compute_dtype"],
|
||||
default_value="float32")
|
||||
bnb_4bit_quant_storage = get_safe_value(config,
|
||||
["bnb_4bit_quant_storage"],
|
||||
default_value="uint8")
|
||||
bnb_4bit_quant_type = get_safe_value(config, ["bnb_4bit_quant_type"],
|
||||
default_value="fp4")
|
||||
bnb_4bit_use_double_quant = get_safe_value(
|
||||
config, ["bnb_4bit_use_double_quant"], default_value=False)
|
||||
llm_int8_enable_fp32_cpu_offload = get_safe_value(
|
||||
config, ["llm_int8_enable_fp32_cpu_offload"], default_value=False)
|
||||
llm_int8_has_fp16_weight = get_safe_value(config,
|
||||
["llm_int8_has_fp16_weight"],
|
||||
default_value=False)
|
||||
llm_int8_skip_modules = get_safe_value(config,
|
||||
["llm_int8_skip_modules"],
|
||||
default_value=[])
|
||||
llm_int8_threshold = get_safe_value(config, ["llm_int8_threshold"],
|
||||
default_value=6.0)
|
||||
|
||||
return cls(
|
||||
load_in_8bit=load_in_8bit,
|
||||
load_in_4bit=load_in_4bit,
|
||||
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
|
||||
bnb_4bit_quant_storage=bnb_4bit_quant_storage,
|
||||
bnb_4bit_quant_type=bnb_4bit_quant_type,
|
||||
bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
|
||||
llm_int8_enable_fp32_cpu_offload=llm_int8_enable_fp32_cpu_offload,
|
||||
llm_int8_has_fp16_weight=llm_int8_has_fp16_weight,
|
||||
llm_int8_skip_modules=llm_int8_skip_modules,
|
||||
llm_int8_threshold=llm_int8_threshold)
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional[Union["LinearMethodBase", "BitsAndBytesMoEMethod"]]:
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
|
||||
return UnquantizedLinearMethod()
|
||||
return BitsAndBytesLinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return BitsAndBytesMoEMethod(self, layer.moe_config)
|
||||
return None
|
||||
|
||||
|
||||
def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
|
||||
# Split the prefix into its dot-separated components
|
||||
components = prefix.split('.')
|
||||
|
||||
# Check if any of the skip modules exactly matches any component
|
||||
substr_check = any(module_name in components
|
||||
for module_name in llm_int8_skip_modules)
|
||||
|
||||
# Allow certain layers to not be quantized
|
||||
set_components = set(".".join(components[:i + 1])
|
||||
for i in range(len(components)))
|
||||
set_llm_int8_skip_modules = set(llm_int8_skip_modules)
|
||||
prefix_check = len(set_llm_int8_skip_modules & set_components) != 0
|
||||
|
||||
return substr_check or prefix_check
|
||||
|
||||
|
||||
def calculate_quant_ratio(dtype):
|
||||
if dtype.is_floating_point:
|
||||
return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits
|
||||
else:
|
||||
return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits
|
||||
|
||||
|
||||
class BitsAndBytesLinearMethod(LinearMethodBase):
|
||||
"""Linear method for BitsAndBytes.
|
||||
|
||||
Args:
|
||||
quant_config: The BitsAndBytes quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: BitsAndBytesConfig):
|
||||
try:
|
||||
import bitsandbytes
|
||||
if version.parse(
|
||||
bitsandbytes.__version__) < version.parse("0.46.1"):
|
||||
raise ImportError("bitsandbytes version is wrong. Please "
|
||||
"install bitsandbytes>=0.46.1.")
|
||||
except ImportError as err:
|
||||
raise ImportError("Please install bitsandbytes>=0.46.1 via "
|
||||
"`pip install bitsandbytes>=0.46.1` to use "
|
||||
"bitsandbytes quantizer.") from err
|
||||
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
from bitsandbytes.nn import Int8Params
|
||||
|
||||
def create_qweight_for_8bit():
|
||||
qweight = Int8Params(
|
||||
data=torch.empty(sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=torch.int8),
|
||||
has_fp16_weights=self.quant_config.llm_int8_has_fp16_weight,
|
||||
requires_grad=False)
|
||||
set_weight_attrs(
|
||||
qweight, {
|
||||
"input_dim": 0,
|
||||
"output_dim": 0,
|
||||
"pack_factor": 1,
|
||||
"use_bitsandbytes_8bit": True,
|
||||
"generation": 0
|
||||
})
|
||||
return qweight
|
||||
|
||||
def create_qweight_for_4bit():
|
||||
quant_ratio = calculate_quant_ratio(params_dtype)
|
||||
|
||||
total_size = input_size_per_partition * sum(output_partition_sizes)
|
||||
if total_size % quant_ratio != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape.")
|
||||
|
||||
qweight = torch.nn.Parameter(torch.empty(total_size // quant_ratio,
|
||||
1,
|
||||
dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
set_weight_attrs(
|
||||
qweight, {
|
||||
"input_dim": 0,
|
||||
"output_dim": 0,
|
||||
"pack_factor": quant_ratio,
|
||||
"use_bitsandbytes_4bit": True
|
||||
})
|
||||
return qweight
|
||||
|
||||
if self.quant_config.load_in_8bit:
|
||||
qweight = create_qweight_for_8bit()
|
||||
else:
|
||||
qweight = create_qweight_for_4bit()
|
||||
# Enable parameters to have the same name as in the BNB
|
||||
# checkpoint format.
|
||||
layer.register_parameter("weight", qweight)
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
if self.quant_config.load_in_8bit:
|
||||
return self._apply_8bit_weight(layer, x, bias)
|
||||
else:
|
||||
return self._apply_4bit_weight(layer, x, bias)
|
||||
|
||||
def _apply_8bit_weight(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
# only load the bitsandbytes module when needed
|
||||
from bitsandbytes import MatmulLtState, matmul
|
||||
|
||||
original_type = x.dtype
|
||||
original_shape = x.shape
|
||||
reshape_after_matmul = False
|
||||
if x.ndim > 2:
|
||||
x = x.reshape(-1, x.size(-1))
|
||||
reshape_after_matmul = True
|
||||
bf_x = x.to(torch.bfloat16)
|
||||
|
||||
qweight = layer.weight
|
||||
offsets = qweight.bnb_shard_offsets
|
||||
quant_states = qweight.bnb_quant_state
|
||||
matmul_states = qweight.matmul_state
|
||||
generation = qweight.generation
|
||||
|
||||
out_dim_0 = x.shape[0]
|
||||
out_dim_1 = sum(
|
||||
[quant_state[1].shape[0] for quant_state in quant_states.items()])
|
||||
out = torch.empty(out_dim_0,
|
||||
out_dim_1,
|
||||
dtype=torch.float16,
|
||||
device=x.device)
|
||||
|
||||
current_index = 0
|
||||
for i in range(len(quant_states)):
|
||||
output_size = quant_states[i].shape[0]
|
||||
|
||||
# in profile_run or the first generation of inference,
|
||||
# create new matmul_states
|
||||
if generation == 0 or generation == 1:
|
||||
matmul_states[i] = MatmulLtState()
|
||||
matmul_states[i].CB = qweight[offsets[i]:offsets[i + 1]]
|
||||
matmul_states[i].SCB = quant_states[i].to(x.device)
|
||||
matmul_states[i].threshold = (
|
||||
self.quant_config.llm_int8_threshold)
|
||||
matmul_states[i].has_fp16_weights = (
|
||||
self.quant_config.llm_int8_has_fp16_weight)
|
||||
matmul_states[i].is_training = False
|
||||
if matmul_states[i].threshold > 0.0 and not matmul_states[
|
||||
i].has_fp16_weights:
|
||||
matmul_states[i].use_pool = True
|
||||
|
||||
new_x = bf_x.unsqueeze(0)
|
||||
|
||||
out[:, current_index:current_index + output_size] = matmul(
|
||||
new_x,
|
||||
qweight[offsets[i]:offsets[i + 1]],
|
||||
state=matmul_states[i])
|
||||
|
||||
current_index += output_size
|
||||
|
||||
# only update the matmul_states if it is not profile_run
|
||||
if (generation > 0
|
||||
and not self.quant_config.llm_int8_has_fp16_weight
|
||||
and matmul_states[i].CB is not None
|
||||
and matmul_states[i].CxB is not None):
|
||||
del matmul_states[i].CB
|
||||
qweight[offsets[i]:offsets[i + 1]] = matmul_states[i].CxB
|
||||
|
||||
out = out.to(original_type)
|
||||
|
||||
if reshape_after_matmul:
|
||||
out = out.view(*original_shape[:-1], out.size(-1))
|
||||
|
||||
if bias is not None:
|
||||
out += bias
|
||||
|
||||
qweight.generation += 1
|
||||
|
||||
return out
|
||||
|
||||
def _apply_4bit_weight(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
original_type = x.dtype
|
||||
original_shape = x.shape
|
||||
reshape_after_matmul = False
|
||||
if x.ndim > 2:
|
||||
x = x.reshape(-1, x.size(-1))
|
||||
reshape_after_matmul = True
|
||||
bf_x = x.to(torch.bfloat16)
|
||||
|
||||
qweight = layer.weight
|
||||
quant_states = qweight.bnb_quant_state
|
||||
offsets = qweight.bnb_shard_offsets
|
||||
|
||||
out_dim_0 = x.shape[0]
|
||||
out_dim_1 = sum(
|
||||
[quant_state[1].shape[0] for quant_state in quant_states.items()])
|
||||
out = torch.empty(out_dim_0,
|
||||
out_dim_1,
|
||||
dtype=torch.bfloat16,
|
||||
device=x.device)
|
||||
apply_bnb_4bit(bf_x, qweight, offsets, out)
|
||||
out = out.to(original_type)
|
||||
|
||||
if reshape_after_matmul:
|
||||
out = out.view(*original_shape[:-1], out.size(-1))
|
||||
|
||||
if bias is not None:
|
||||
out += bias
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _apply_bnb_4bit(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
offsets: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
) -> None:
|
||||
# only load the bitsandbytes module when needed
|
||||
from bitsandbytes import matmul_4bit
|
||||
quant_states = weight.bnb_quant_state
|
||||
current_index = 0
|
||||
for i in range(len(quant_states)):
|
||||
output_size = quant_states[i].shape[0]
|
||||
# It is more efficient to use out kwarg like
|
||||
# matmul_4bit(..., out = ...). Infeasible now due to the bug
|
||||
# https://github.com/TimDettmers/bitsandbytes/issues/1235.
|
||||
# Need to change after the bug is fixed.
|
||||
out[:, current_index:current_index + output_size] = matmul_4bit(
|
||||
x, weight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
|
||||
current_index += output_size
|
||||
|
||||
|
||||
def _apply_bnb_4bit_fake(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
offsets: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(op_name="apply_bnb_4bit",
|
||||
op_func=_apply_bnb_4bit,
|
||||
mutates_args=["out"],
|
||||
fake_impl=_apply_bnb_4bit_fake,
|
||||
dispatch_key=current_platform.dispatch_key)
|
||||
apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
|
||||
class BitsAndBytesMoEMethod(FusedMoEMethodBase):
|
||||
"""MoE method for BitsAndBytes.
|
||||
|
||||
Args:
|
||||
quant_config: The BitsAndBytes quantization config.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: BitsAndBytesConfig,
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
try:
|
||||
import bitsandbytes
|
||||
if version.parse(
|
||||
bitsandbytes.__version__) < version.parse("0.46.1"):
|
||||
raise ImportError("bitsandbytes version is wrong. Please "
|
||||
"install bitsandbytes>=0.46.1.")
|
||||
except ImportError as err:
|
||||
raise ImportError("Please install bitsandbytes>=0.46.1 via "
|
||||
"`pip install bitsandbytes>=0.46.1` to use "
|
||||
"bitsandbytes quantizer.") from err
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
if self.quant_config.load_in_8bit:
|
||||
call_fun = self._create_weights_8bit
|
||||
else:
|
||||
call_fun = self._create_weights_4bit
|
||||
call_fun(
|
||||
layer,
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
params_dtype,
|
||||
**extra_weight_attrs,
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
return None
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `BitsAndBytesMoEMethod` yet.")
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
if self.quant_config.load_in_8bit:
|
||||
w13, w2 = self._apply_8bit_dequant(layer)
|
||||
else:
|
||||
w13, w2 = self._apply_4bit_dequnt(layer)
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=w13,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
def _create_weights_4bit(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
quant_ratio = calculate_quant_ratio(params_dtype)
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_total_size = (hidden_size * 2 *
|
||||
intermediate_size_per_partition) // quant_ratio
|
||||
w13_qweight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
w13_total_size,
|
||||
1,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_qweight)
|
||||
set_weight_attrs(w13_qweight, extra_weight_attrs)
|
||||
set_weight_attrs(
|
||||
w13_qweight,
|
||||
{
|
||||
"num_experts":
|
||||
num_experts,
|
||||
"input_dim":
|
||||
hidden_size,
|
||||
"output_dim":
|
||||
2 * intermediate_size_per_partition,
|
||||
"experts_shape": (
|
||||
num_experts,
|
||||
intermediate_size_per_partition * 2,
|
||||
hidden_size,
|
||||
),
|
||||
"pack_factor":
|
||||
quant_ratio,
|
||||
"use_bitsandbytes_4bit":
|
||||
True,
|
||||
},
|
||||
)
|
||||
# down_proj (row parallel)
|
||||
w2_total_size = (hidden_size *
|
||||
intermediate_size_per_partition) // quant_ratio
|
||||
w2_qweight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
w2_total_size,
|
||||
1,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
w2_qweight,
|
||||
{
|
||||
"num_experts":
|
||||
num_experts,
|
||||
"input_dim":
|
||||
intermediate_size_per_partition,
|
||||
"output_dim":
|
||||
hidden_size,
|
||||
"experts_shape": (
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
),
|
||||
"pack_factor":
|
||||
quant_ratio,
|
||||
"use_bitsandbytes_4bit":
|
||||
True,
|
||||
},
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_qweight)
|
||||
set_weight_attrs(w2_qweight, extra_weight_attrs)
|
||||
|
||||
def _create_weights_8bit(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def _apply_4bit_dequnt(
|
||||
self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
from bitsandbytes.functional import dequantize_4bit
|
||||
w13 = dequantize_4bit(
|
||||
layer.w13_weight.reshape(-1, 1),
|
||||
layer.w13_weight.bnb_quant_state,
|
||||
)
|
||||
w2 = dequantize_4bit(
|
||||
layer.w2_weight.reshape(-1, 1),
|
||||
layer.w2_weight.bnb_quant_state,
|
||||
)
|
||||
w13 = w13.reshape(layer.w13_weight.experts_shape)
|
||||
w2 = w2.reshape(layer.w2_weight.experts_shape)
|
||||
return w13, w2
|
||||
|
||||
def _apply_8bit_dequant(
|
||||
self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,797 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from contextlib import suppress
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
|
||||
import torch
|
||||
from compressed_tensors.config import (CompressionFormat,
|
||||
SparsityCompressionConfig,
|
||||
SparsityStructure)
|
||||
from compressed_tensors.quantization import (QuantizationArgs,
|
||||
QuantizationStrategy,
|
||||
QuantizationType)
|
||||
from compressed_tensors.transform import TransformConfig
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
|
||||
CompressedTensorsMoEMethod)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
|
||||
CompressedTensorsScheme, CompressedTensorsW4A4Fp4,
|
||||
CompressedTensorsW4A8Fp8, CompressedTensorsW4A8Int,
|
||||
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
|
||||
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
|
||||
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
|
||||
CompressedTensorsLinearTransformMethod, get_linear_transform_schemes)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
find_matched_target, is_activation_quantization_format,
|
||||
should_ignore_layer)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
cutlass_fp4_supported)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["CompressedTensorsLinearMethod"]
|
||||
|
||||
SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
|
||||
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]]
|
||||
|
||||
|
||||
class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_scheme_map: dict[str, Any],
|
||||
ignore: list[str],
|
||||
quant_format: str,
|
||||
sparsity_scheme_map: dict[str, SparsityCompressionConfig],
|
||||
sparsity_ignore_list: list[str],
|
||||
kv_cache_scheme: Optional[dict[str, Any]] = None,
|
||||
config: Optional[dict[str, Any]] = None,
|
||||
transform_config: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.ignore = ignore
|
||||
self.quant_format = quant_format
|
||||
# Map from [target -> scheme]
|
||||
self.target_scheme_map = target_scheme_map
|
||||
self.kv_cache_scheme = kv_cache_scheme
|
||||
self.sparsity_scheme_map = sparsity_scheme_map
|
||||
self.sparsity_ignore_list = sparsity_ignore_list
|
||||
self.config = config
|
||||
|
||||
if transform_config:
|
||||
self.transform_config = TransformConfig.model_validate(
|
||||
transform_config)
|
||||
else:
|
||||
self.transform_config = None
|
||||
|
||||
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
|
||||
return CompressedTensorsLinearMethod(self)
|
||||
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float32, torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "compressed-tensors"
|
||||
|
||||
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
|
||||
self.target_scheme_map = hf_to_vllm_mapper.apply_dict(
|
||||
self.target_scheme_map)
|
||||
self.ignore = hf_to_vllm_mapper.apply_list(self.ignore)
|
||||
self.sparsity_scheme_map = hf_to_vllm_mapper.apply_dict(
|
||||
self.sparsity_scheme_map)
|
||||
self.sparsity_ignore_list = hf_to_vllm_mapper.apply_list(
|
||||
self.sparsity_ignore_list)
|
||||
if self.kv_cache_scheme is not None:
|
||||
self.kv_cache_scheme = hf_to_vllm_mapper.apply_dict(
|
||||
self.kv_cache_scheme)
|
||||
|
||||
def get_quant_method(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
# collect schemes
|
||||
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
|
||||
input_tfms, output_tfms = get_linear_transform_schemes(
|
||||
layer, prefix, self.transform_config,
|
||||
self.packed_modules_mapping)
|
||||
|
||||
# choose quantization method
|
||||
quant_method: LinearMethodBase = UnquantizedLinearMethod()
|
||||
if quant_scheme is not None:
|
||||
layer.scheme = quant_scheme
|
||||
quant_method = CompressedTensorsLinearMethod(self)
|
||||
|
||||
# choose transform method
|
||||
if any((input_tfms, output_tfms)):
|
||||
return CompressedTensorsLinearTransformMethod.from_schemes(
|
||||
quant_method, quant_scheme, input_tfms, output_tfms)
|
||||
|
||||
else:
|
||||
return quant_method
|
||||
|
||||
if isinstance(layer, Attention):
|
||||
return CompressedTensorsKVCacheMethod(self)
|
||||
if isinstance(layer, FusedMoE):
|
||||
return CompressedTensorsMoEMethod.get_moe_method(self, layer)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
|
||||
ignore: list[str] = cast(list[str], config.get("ignore", []))
|
||||
quant_format = cast(str, config.get("format"))
|
||||
target_scheme_map = cls._quantization_scheme_map_from_config(
|
||||
config=config)
|
||||
sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
|
||||
config=config)
|
||||
transform_config = config.get("transform_config")
|
||||
|
||||
return cls(
|
||||
target_scheme_map=target_scheme_map,
|
||||
ignore=ignore,
|
||||
quant_format=quant_format,
|
||||
sparsity_scheme_map=sparsity_scheme_map,
|
||||
sparsity_ignore_list=sparsity_ignore_list,
|
||||
config=config,
|
||||
transform_config=transform_config,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _parse_sparsity_config(
|
||||
cls, config: dict[str, Any]
|
||||
) -> tuple[dict[str, SparsityCompressionConfig], list[str]]:
|
||||
"""
|
||||
:param config: The `quantization_config` dictionary from config.json
|
||||
:return: A tuple with two elements
|
||||
1. A dictionary mapping target layer names to their corresponding
|
||||
sparsity_config
|
||||
2. A list of layer names to ignore for sparsity
|
||||
"""
|
||||
if not (sparsity_config := config.get(SPARSITY_CONFIG_NAME)):
|
||||
return dict(), []
|
||||
|
||||
sparsity_config = SparsityCompressionConfig.model_validate(
|
||||
sparsity_config)
|
||||
sparse_scheme_map: dict[str, SparsityCompressionConfig] = {
|
||||
target: sparsity_config
|
||||
for target in sparsity_config.targets or list()
|
||||
}
|
||||
sparsity_ignore_list = sparsity_config.ignore or list()
|
||||
return sparse_scheme_map, sparsity_ignore_list
|
||||
|
||||
@classmethod
|
||||
def _quantization_scheme_map_from_config(
|
||||
cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
|
||||
"""
|
||||
:param config: The `quantization_config` dictionary from config.json
|
||||
:return: A dictionary mapping target layer names to their corresponding
|
||||
quantization_args for weights and input activations
|
||||
"""
|
||||
target_scheme_map: dict[str, Any] = dict()
|
||||
quant_format = cast(str, config.get("format"))
|
||||
|
||||
# The quant_config has multiple config_groups, each containing
|
||||
# an input_activations key with details about how the activations are
|
||||
# quantized, a weights key indicating how the weights are quantized,
|
||||
# and a list of targets under the `targets` key, dictating which
|
||||
# layers are impacted by the quantization details. The quantization
|
||||
# details follow the structure defined by the QuantizationArgs
|
||||
# pydantic model, which is used to verify the structure of the
|
||||
# quant_config and also store the details for later use.
|
||||
|
||||
config_groups = config.get("config_groups", dict())
|
||||
for _, quant_config in config_groups.items():
|
||||
targets = quant_config.get("targets")
|
||||
for target in targets:
|
||||
target_scheme_map[target] = {}
|
||||
target_scheme_map[target][
|
||||
"weights"] = QuantizationArgs.model_validate(
|
||||
quant_config.get("weights"))
|
||||
|
||||
target_scheme_map[target]["input_activations"] = None
|
||||
target_scheme_map[target]["format"] = quant_config.get(
|
||||
"format")
|
||||
format = target_scheme_map[target].get("format")
|
||||
# If no per-config format defined, use global format in config
|
||||
act_quant_format = is_activation_quantization_format(
|
||||
format
|
||||
) if format is not None else is_activation_quantization_format(
|
||||
quant_format)
|
||||
# TODO(czhu): w4a8fp8 is in packed-quantized format
|
||||
# but needs input activation quantization
|
||||
input_activations = quant_config.get("input_activations")
|
||||
if act_quant_format or input_activations:
|
||||
# The only case where we have activation quant supported
|
||||
# but no input_activations provided in the config
|
||||
# should be w8a16fp8 w8a16fp8 can also run for cases where
|
||||
# there is an input_quant but it is ignored
|
||||
if not input_activations:
|
||||
assert target_scheme_map[target][
|
||||
"weights"].type == QuantizationType.FLOAT
|
||||
else:
|
||||
target_scheme_map[target][
|
||||
"input_activations"] = QuantizationArgs.model_validate( # noqa: E501
|
||||
quant_config.get("input_activations"))
|
||||
return target_scheme_map
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
def _check_scheme_supported(self,
|
||||
min_capability: int,
|
||||
error: bool = True,
|
||||
match_exact: bool = False) -> bool:
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
|
||||
if capability_tuple is not None:
|
||||
capability = capability_tuple.to_int()
|
||||
if match_exact:
|
||||
supported = capability == min_capability
|
||||
if error and not supported:
|
||||
raise RuntimeError(
|
||||
"Quantization scheme is not supported for ",
|
||||
"the current GPU. Required capability: ",
|
||||
f"{min_capability}. Current capability: {capability}.")
|
||||
else:
|
||||
supported = capability >= min_capability
|
||||
if error and not supported:
|
||||
raise RuntimeError(
|
||||
"Quantization scheme is not supported for ",
|
||||
f"the current GPU. Min capability: {min_capability}. ",
|
||||
f"Current capability: {capability}.")
|
||||
return supported
|
||||
else:
|
||||
return False
|
||||
|
||||
def _is_fp4a4_nvfp4(self, weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs):
|
||||
|
||||
if weight_quant is None or input_quant is None:
|
||||
return False
|
||||
|
||||
is_tensor_group_quant = (weight_quant.strategy
|
||||
== QuantizationStrategy.TENSOR_GROUP.value
|
||||
and input_quant.strategy
|
||||
== QuantizationStrategy.TENSOR_GROUP.value)
|
||||
is_symmetric = weight_quant.symmetric and input_quant.symmetric
|
||||
|
||||
is_group_size_16 = (weight_quant.group_size == 16
|
||||
and input_quant.group_size == 16)
|
||||
is_float_type = (weight_quant.type == QuantizationType.FLOAT
|
||||
and input_quant.type == QuantizationType.FLOAT.value)
|
||||
is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4
|
||||
|
||||
return (is_tensor_group_quant and is_float_type and is_4_bits
|
||||
and is_group_size_16 and is_symmetric)
|
||||
|
||||
def _is_fp4a16_nvfp4(self, weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs):
|
||||
|
||||
is_weight_only = weight_quant is not None and input_quant is None
|
||||
is_tensor_group_quant = (
|
||||
weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value)
|
||||
is_symmetric = weight_quant.symmetric
|
||||
|
||||
is_group_size_16 = weight_quant.group_size == 16
|
||||
is_float_type = weight_quant.type == QuantizationType.FLOAT
|
||||
is_4_bits = weight_quant.num_bits == 4
|
||||
|
||||
return (is_weight_only and is_tensor_group_quant and is_float_type
|
||||
and is_4_bits and is_group_size_16 and is_symmetric)
|
||||
|
||||
def _is_static_tensor_w8a8(self, weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs) -> bool:
|
||||
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
|
||||
weight_strategy = (
|
||||
weight_quant.strategy == QuantizationStrategy.TENSOR.value
|
||||
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
|
||||
is_tensor = (weight_strategy and input_quant.strategy
|
||||
== QuantizationStrategy.TENSOR.value)
|
||||
is_static = not weight_quant.dynamic and not input_quant.dynamic
|
||||
|
||||
# Both symmetric and asymmetric input quantization supported.
|
||||
# Only symmetric weight quantization supported.
|
||||
return is_8_bits and is_tensor and weight_quant.symmetric and is_static
|
||||
|
||||
def _is_dynamic_token_w8a8(self, weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs) -> bool:
|
||||
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
|
||||
weight_strategy = (
|
||||
weight_quant.strategy == QuantizationStrategy.TENSOR.value
|
||||
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
|
||||
is_token = (weight_strategy and input_quant.strategy
|
||||
== QuantizationStrategy.TOKEN.value)
|
||||
is_dynamic = not weight_quant.dynamic and input_quant.dynamic
|
||||
|
||||
# Both symmetric and asymmetric input quantization supported.
|
||||
# Only symmetric weight quantization supported.
|
||||
return is_8_bits and is_token and weight_quant.symmetric and is_dynamic
|
||||
|
||||
def _is_dynamic_token_w4a8_int(self, weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs) -> bool:
|
||||
is_weight_4_bits = weight_quant.num_bits == 4
|
||||
is_activation_8_bits = input_quant.num_bits == 8
|
||||
weight_strategy = (
|
||||
weight_quant.strategy == QuantizationStrategy.GROUP.value
|
||||
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
|
||||
is_token = (weight_strategy and input_quant.strategy
|
||||
== QuantizationStrategy.TOKEN.value)
|
||||
is_dynamic = not weight_quant.dynamic and input_quant.dynamic
|
||||
|
||||
# Both symmetric and asymmetric input quantization supported.
|
||||
# Only symmetric weight quantization supported.
|
||||
return (is_weight_4_bits and is_activation_8_bits and is_token
|
||||
and weight_quant.symmetric and is_dynamic)
|
||||
|
||||
def _is_fp8_w8a8(self, weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs) -> bool:
|
||||
# Confirm weights and activations quantized.
|
||||
if weight_quant is None or input_quant is None:
|
||||
return False
|
||||
|
||||
# Confirm weight scheme is supported.
|
||||
is_floating_point = (weight_quant.type == QuantizationType.FLOAT
|
||||
and input_quant.type == QuantizationType.FLOAT)
|
||||
is_symmetric_weight = weight_quant.symmetric
|
||||
is_static_weight = not weight_quant.dynamic
|
||||
is_tensor_or_channel_or_block_weight = (weight_quant.strategy in [
|
||||
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL,
|
||||
QuantizationStrategy.BLOCK
|
||||
])
|
||||
if not (is_floating_point and is_symmetric_weight and is_static_weight
|
||||
and is_tensor_or_channel_or_block_weight):
|
||||
return False
|
||||
|
||||
# Dynamic quantization is always supported if weights supported.
|
||||
if input_quant.dynamic:
|
||||
return True
|
||||
|
||||
# Confirm activation scheme is supported.
|
||||
is_symmetric_activation = input_quant.symmetric
|
||||
is_per_tensor_activation = (
|
||||
input_quant.strategy == QuantizationStrategy.TENSOR)
|
||||
return is_symmetric_activation and is_per_tensor_activation
|
||||
|
||||
def _is_fp8_w4a8(self, weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs) -> bool:
|
||||
if not weight_quant or not input_quant:
|
||||
return False
|
||||
is_weight_4_bits = weight_quant.num_bits == 4
|
||||
is_activation_8_bits = input_quant.num_bits == 8
|
||||
weight_strategy = (
|
||||
weight_quant.strategy == QuantizationStrategy.GROUP.value)
|
||||
is_token = (weight_strategy and input_quant.strategy
|
||||
== QuantizationStrategy.TOKEN.value)
|
||||
is_dynamic = not weight_quant.dynamic and input_quant.dynamic
|
||||
is_symmetric = weight_quant.symmetric and input_quant.symmetric
|
||||
# Only per-group symmetric weight (4bit)
|
||||
# + per-tok symmetric activation (8bit) quantization supported.
|
||||
return (is_weight_4_bits and is_activation_8_bits and is_token
|
||||
and is_symmetric and is_dynamic)
|
||||
|
||||
def _is_fp8_w4a8_sm90(self, weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs) -> bool:
|
||||
return (self._check_scheme_supported(90, error=False, match_exact=True)
|
||||
and self._is_fp8_w4a8(weight_quant, input_quant))
|
||||
|
||||
def _is_fp8_w8a8_sm90(self, weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs) -> bool:
|
||||
return (self._check_scheme_supported(90, error=False, match_exact=True)
|
||||
and self._is_fp8_w8a8(weight_quant, input_quant))
|
||||
|
||||
def _is_fp8_w8a8_sm100(self, weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs) -> bool:
|
||||
return (self._check_scheme_supported(
|
||||
100, error=False, match_exact=True)
|
||||
and self._is_fp8_w8a8(weight_quant, input_quant))
|
||||
|
||||
def _is_fp8_w8a16(self, weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs) -> bool:
|
||||
# Confirm weights quantized.
|
||||
if weight_quant is None:
|
||||
return False
|
||||
|
||||
# Confirm we have floating points.
|
||||
if weight_quant.type != QuantizationType.FLOAT:
|
||||
return False
|
||||
|
||||
# Confirm weight scheme is supported.
|
||||
is_symmetric_weight = weight_quant.symmetric
|
||||
is_static_weight = not weight_quant.dynamic
|
||||
is_tensor_or_channel_or_block_weight = (weight_quant.strategy in [
|
||||
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL,
|
||||
QuantizationStrategy.BLOCK
|
||||
])
|
||||
if not (is_symmetric_weight and is_static_weight # noqa: SIM103
|
||||
and is_tensor_or_channel_or_block_weight):
|
||||
return False
|
||||
|
||||
# All conditions satisfied.
|
||||
return True
|
||||
|
||||
def _is_wNa16_group_channel(self, weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs) -> bool:
|
||||
input_quant_none = input_quant is None
|
||||
is_channel_group = (
|
||||
weight_quant.strategy == QuantizationStrategy.CHANNEL.value
|
||||
or weight_quant.strategy == QuantizationStrategy.GROUP.value)
|
||||
is_static = not weight_quant.dynamic
|
||||
|
||||
return (is_channel_group and input_quant_none and is_static)
|
||||
|
||||
def _get_scheme_from_parts(
|
||||
self,
|
||||
weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs,
|
||||
format: Optional[str] = None) -> "CompressedTensorsScheme":
|
||||
|
||||
# use the per-layer format if defined, otherwise, use global format
|
||||
format = format if format is not None else self.quant_format
|
||||
|
||||
# Detect If Mixed Precision
|
||||
if self._is_fp4a16_nvfp4(weight_quant, input_quant):
|
||||
return CompressedTensorsW4A16Fp4()
|
||||
|
||||
if self._is_fp8_w4a8_sm90(weight_quant, input_quant):
|
||||
return CompressedTensorsW4A8Fp8(num_bits=weight_quant.num_bits,
|
||||
strategy=weight_quant.strategy,
|
||||
symmetric=weight_quant.symmetric,
|
||||
group_size=weight_quant.group_size,
|
||||
actorder=weight_quant.actorder)
|
||||
|
||||
if self._is_wNa16_group_channel(weight_quant, input_quant):
|
||||
if (format == CompressionFormat.marlin_24.value
|
||||
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
|
||||
assert weight_quant.symmetric
|
||||
return CompressedTensorsW4A16Sparse24(
|
||||
strategy=weight_quant.strategy,
|
||||
num_bits=weight_quant.num_bits,
|
||||
group_size=weight_quant.group_size)
|
||||
if (format == CompressionFormat.pack_quantized.value
|
||||
and weight_quant.num_bits in WNA16_SUPPORTED_BITS):
|
||||
return CompressedTensorsWNA16(
|
||||
num_bits=weight_quant.num_bits,
|
||||
strategy=weight_quant.strategy,
|
||||
symmetric=weight_quant.symmetric,
|
||||
group_size=weight_quant.group_size,
|
||||
actorder=weight_quant.actorder)
|
||||
|
||||
act_quant_format = is_activation_quantization_format(format)
|
||||
if act_quant_format:
|
||||
if self._is_fp4a4_nvfp4(weight_quant, input_quant):
|
||||
if cutlass_fp4_supported(
|
||||
) or envs.VLLM_USE_NVFP4_CT_EMULATIONS:
|
||||
return CompressedTensorsW4A4Fp4()
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Current platform does not support cutlass NVFP4."
|
||||
" Running CompressedTensorsW4A16Fp4.")
|
||||
return CompressedTensorsW4A16Fp4(
|
||||
has_input_global_scale=True)
|
||||
|
||||
if self._is_fp8_w8a8(weight_quant, input_quant):
|
||||
is_fp8_w8a8_supported = self._check_scheme_supported(
|
||||
CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
|
||||
if is_fp8_w8a8_supported:
|
||||
return CompressedTensorsW8A8Fp8(
|
||||
weight_quant=weight_quant,
|
||||
is_static_input_scheme=(input_quant
|
||||
and not input_quant.dynamic))
|
||||
else:
|
||||
# note: input_quant will be present for converted models;
|
||||
# will be ignored during inference post loading
|
||||
return CompressedTensorsW8A16Fp8(
|
||||
strategy=weight_quant.strategy,
|
||||
is_static_input_scheme=not input_quant.dynamic)
|
||||
|
||||
# note: input_quant can be None
|
||||
if self._is_fp8_w8a16(weight_quant, input_quant):
|
||||
is_static_input_scheme = (input_quant
|
||||
and not input_quant.dynamic)
|
||||
return CompressedTensorsW8A16Fp8(
|
||||
strategy=weight_quant.strategy,
|
||||
is_static_input_scheme=is_static_input_scheme)
|
||||
|
||||
if self._is_static_tensor_w8a8(weight_quant, input_quant):
|
||||
return CompressedTensorsW8A8Int8(
|
||||
strategy=weight_quant.strategy,
|
||||
is_static_input_scheme=True,
|
||||
input_symmetric=input_quant.symmetric)
|
||||
|
||||
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
|
||||
return CompressedTensorsW8A8Int8(
|
||||
strategy=weight_quant.strategy,
|
||||
is_static_input_scheme=False,
|
||||
input_symmetric=input_quant.symmetric)
|
||||
|
||||
if self._is_dynamic_token_w4a8_int(weight_quant, input_quant):
|
||||
is_static_input_scheme = (input_quant
|
||||
and not input_quant.dynamic)
|
||||
return CompressedTensorsW4A8Int(
|
||||
num_bits=weight_quant.num_bits,
|
||||
strategy=weight_quant.strategy,
|
||||
group_size=weight_quant.group_size,
|
||||
is_static_input_scheme=is_static_input_scheme,
|
||||
input_symmetric=input_quant.symmetric)
|
||||
|
||||
raise NotImplementedError(
|
||||
"No compressed-tensors compatible scheme was found.")
|
||||
|
||||
def get_scheme(self,
|
||||
layer: torch.nn.Module,
|
||||
layer_name: Optional[str] = None
|
||||
) -> Optional["CompressedTensorsScheme"]:
|
||||
"""
|
||||
compressed-tensors supports non uniform in the following way:
|
||||
|
||||
targets of config_groups: There can be N config_groups which each
|
||||
have a quantization scheme. Each config_group has a list of targets
|
||||
which can be a full layer_name, a regex for a layer_name, or
|
||||
an nn.Module name.
|
||||
|
||||
Detect whether a layer_name is found in any target and
|
||||
use the quantization scheme corresponding to the matched target
|
||||
to select the CompressedTensorsScheme used for inference.
|
||||
"""
|
||||
|
||||
# Find the "target" in the compressed-tensors config
|
||||
# that our layer conforms to.
|
||||
# TODO (@kylesayrs): support ignore module names with ct matching utils
|
||||
if should_ignore_layer(layer_name,
|
||||
ignore=self.ignore,
|
||||
fused_mapping=self.packed_modules_mapping):
|
||||
return None
|
||||
|
||||
# Will be empty for models with only sparsity
|
||||
weight_quant = input_quant = None
|
||||
if self.target_scheme_map:
|
||||
matched_target = find_matched_target(
|
||||
layer_name=layer_name,
|
||||
module=layer,
|
||||
targets=self.target_scheme_map.keys(),
|
||||
fused_mapping=self.packed_modules_mapping)
|
||||
|
||||
scheme_dict = self.target_scheme_map[matched_target]
|
||||
weight_quant = scheme_dict.get("weights")
|
||||
input_quant = scheme_dict.get("input_activations")
|
||||
format = scheme_dict.get("format")
|
||||
|
||||
# Find the sparsity scheme of the layer
|
||||
# assume that fused layers inherit first component's sparsity scheme
|
||||
sparsity_targets = (self.sparsity_scheme_map.keys() -
|
||||
set(self.sparsity_ignore_list))
|
||||
sparsity_scheme: Optional[SparsityCompressionConfig] = None
|
||||
with suppress(ValueError):
|
||||
matched_target = find_matched_target(
|
||||
layer_name=layer_name,
|
||||
module=layer,
|
||||
targets=sparsity_targets,
|
||||
fused_mapping=self.packed_modules_mapping)
|
||||
sparsity_scheme = self.sparsity_scheme_map[matched_target]
|
||||
|
||||
if self.supports_cutlass_24(weight_quant=weight_quant,
|
||||
input_quant=input_quant,
|
||||
sparsity_scheme=sparsity_scheme):
|
||||
# Have a valid sparsity scheme
|
||||
# Validate layer is supported by Cutlass 2:4 Kernel
|
||||
model_compression_config = (None if sparsity_scheme is None
|
||||
or sparsity_scheme.format == "dense"
|
||||
else self.config)
|
||||
|
||||
scheme = CompressedTensors24(
|
||||
quantized=weight_quant is not None or input_quant is not None,
|
||||
weight_quant=weight_quant,
|
||||
input_quant=input_quant,
|
||||
model_compression_config=model_compression_config,
|
||||
)
|
||||
elif weight_quant is None:
|
||||
logger.warning_once("Acceleration for non-quantized schemes is "
|
||||
"not supported by Compressed Tensors. "
|
||||
"Falling back to UnquantizedLinearMethod")
|
||||
return None
|
||||
|
||||
else:
|
||||
# Find the quant_scheme
|
||||
scheme = self._get_scheme_from_parts( # type: ignore
|
||||
weight_quant=weight_quant,
|
||||
input_quant=input_quant,
|
||||
format=format)
|
||||
|
||||
# Raise error if device does not support the scheme
|
||||
# (e.g. fp8 needs ada lovelace)
|
||||
self._check_scheme_supported(scheme.get_min_capability())
|
||||
logger.debug("Using scheme: %s for %s", scheme.__class__.__name__,
|
||||
layer_name)
|
||||
return scheme
|
||||
|
||||
def get_cache_scale(self, name: str) -> Optional[str]:
|
||||
"""
|
||||
Check whether the param name matches the format for k/v cache scales
|
||||
in compressed-tensors. If this is the case, return its equivalent
|
||||
param name expected by vLLM
|
||||
|
||||
:param name: param name
|
||||
:return: matching param name for KV cache scale in vLLM
|
||||
"""
|
||||
if name.endswith(".output_scale") and ".k_proj" in name:
|
||||
return name.replace(".k_proj.output_scale", ".attn.k_scale")
|
||||
if name.endswith(".output_scale") and ".v_proj" in name:
|
||||
return name.replace(".v_proj.output_scale", ".attn.v_scale")
|
||||
# If no matches, return None
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def supports_cutlass_24(
|
||||
weight_quant: Optional[QuantizationArgs],
|
||||
input_quant: Optional[QuantizationArgs],
|
||||
sparsity_scheme: Optional[SparsityCompressionConfig] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the layer is supported by the Cutlass 2:4 Kernel
|
||||
Conditions:
|
||||
- Overarching condition: Sparsity Structure is 2:4
|
||||
- Unquantized cases are supported
|
||||
- Weight only quantization is not-supported
|
||||
- Supported weight quantization strategies are TENSOR and CHANNEL
|
||||
- Supported input quantization strategies are TENSOR and TOKEN
|
||||
- Only 8 bit quantization is supported
|
||||
|
||||
:return: True if the layer is supported by the Cutlass 2:4 Kernel
|
||||
False otherwise
|
||||
"""
|
||||
if sparsity_scheme is None:
|
||||
return False
|
||||
|
||||
is_valid_sparsity_structure: bool = (
|
||||
sparsity_scheme.sparsity_structure ==
|
||||
SparsityStructure.TWO_FOUR.value)
|
||||
|
||||
valid_compressors = {
|
||||
CompressionFormat.dense.value,
|
||||
CompressionFormat.sparse_24_bitmask.value
|
||||
}
|
||||
|
||||
is_valid_sparsity = (is_valid_sparsity_structure
|
||||
and sparsity_scheme.format in valid_compressors)
|
||||
|
||||
if not is_valid_sparsity:
|
||||
return False
|
||||
|
||||
# Unquantized cases are supported
|
||||
if weight_quant is None and input_quant is None:
|
||||
return True
|
||||
|
||||
# Weight only quantization is not-supported
|
||||
if weight_quant is not None and input_quant is None:
|
||||
return False
|
||||
|
||||
supported_weight_quant_strategies = [
|
||||
QuantizationStrategy.TENSOR.value,
|
||||
QuantizationStrategy.CHANNEL.value
|
||||
]
|
||||
|
||||
assert weight_quant is not None
|
||||
assert input_quant is not None
|
||||
if weight_quant.strategy not in supported_weight_quant_strategies:
|
||||
return False
|
||||
|
||||
supported_input_quant_strategies = [
|
||||
QuantizationStrategy.TENSOR.value, QuantizationStrategy.TOKEN.value
|
||||
]
|
||||
|
||||
if input_quant.strategy not in supported_input_quant_strategies:
|
||||
return False
|
||||
|
||||
return weight_quant.num_bits == input_quant.num_bits == 8
|
||||
|
||||
|
||||
class CompressedTensorsLinearMethod(LinearMethodBase):
|
||||
|
||||
def __init__(self, quantization_config: CompressedTensorsConfig):
|
||||
self.quantization_config = quantization_config
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.scheme.process_weights_after_loading(layer)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
"""
|
||||
Use the CompressedTensorsScheme associated with each layer to create
|
||||
the necessary parameters for the layer. See LinearMethodBase for param
|
||||
details
|
||||
"""
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
layer.scheme.create_weights(
|
||||
layer=layer,
|
||||
input_size=input_size,
|
||||
input_size_per_partition=input_size_per_partition,
|
||||
output_partition_sizes=output_partition_sizes,
|
||||
output_size=output_size,
|
||||
params_dtype=params_dtype,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
Use the output of create_weights and the CompressedTensorsScheme
|
||||
associated with the layer to apply the forward pass with the
|
||||
layer input. See LinearMethodBase for param details
|
||||
|
||||
"""
|
||||
scheme = layer.scheme
|
||||
if scheme is None:
|
||||
raise ValueError("A scheme must be defined for each layer")
|
||||
return scheme.apply_weights(layer, x, bias=bias)
|
||||
|
||||
|
||||
class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
|
||||
"""
|
||||
Supports loading kv-cache scaling factors from compressed-tensors
|
||||
checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: CompressedTensorsConfig):
|
||||
self.validate_kv_cache_scheme(quant_config.kv_cache_scheme)
|
||||
super().__init__(quant_config)
|
||||
|
||||
@staticmethod
|
||||
def validate_kv_cache_scheme(kv_cache_scheme: Optional[dict[str, Any]]):
|
||||
"""
|
||||
Validator for the kv cache scheme. Useful for controlling the
|
||||
kv cache quantization schemes, that are being supported in vLLM
|
||||
:param kv_cache_scheme: the compressed-tensors kv cache scheme
|
||||
"""
|
||||
if kv_cache_scheme is None:
|
||||
return
|
||||
|
||||
type_ = kv_cache_scheme.get("type")
|
||||
num_bits = kv_cache_scheme.get("num_bits")
|
||||
|
||||
if type_ != "float" and num_bits != 8:
|
||||
raise NotImplementedError(
|
||||
"Currently supported kv cache quantization is "
|
||||
"num_bits=8, type=float, however "
|
||||
f"received num_bits={num_bits}, type={type_}")
|
||||
|
||||
strategy = kv_cache_scheme.get("strategy")
|
||||
if strategy != "tensor":
|
||||
raise NotImplementedError(
|
||||
"Only support per-tensor scaling factor "
|
||||
"for compressed-tensors KV cache. "
|
||||
f"Expected strategy: tensor, found strategy: {strategy}")
|
||||
|
||||
is_symmetric = kv_cache_scheme.get("symmetric")
|
||||
if not is_symmetric:
|
||||
raise NotImplementedError(
|
||||
"Only support symmetric scaling factor "
|
||||
"for compressed-tensors KV cache. "
|
||||
f"However found symmetric: {is_symmetric}")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,27 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .compressed_tensors_scheme import CompressedTensorsScheme
|
||||
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
|
||||
from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8
|
||||
from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
|
||||
from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
|
||||
CompressedTensorsW4A16Sparse24)
|
||||
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
|
||||
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
||||
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
|
||||
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
|
||||
from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
|
||||
CompressedTensorsWNA16)
|
||||
|
||||
from .compressed_tensors_24 import CompressedTensors24 # isort: skip
|
||||
|
||||
__all__ = [
|
||||
"CompressedTensorsScheme", "CompressedTensorsWNA16",
|
||||
"CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24",
|
||||
"CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
|
||||
"WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
|
||||
"CompressedTensors24", "CompressedTensorsW4A16Fp4",
|
||||
"CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int",
|
||||
"CompressedTensorsW4A8Fp8"
|
||||
]
|
||||
@@ -0,0 +1,366 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors import CompressionFormat, ModelCompressor
|
||||
from compressed_tensors.quantization import (QuantizationArgs,
|
||||
QuantizationStrategy,
|
||||
QuantizationType)
|
||||
from compressed_tensors.utils import combine_shards
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise, sparse_cutlass_supported)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
|
||||
__all__ = ["CompressedTensors24"]
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class CompressedTensors24(CompressedTensorsScheme):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quantized: bool = False,
|
||||
weight_quant: Optional[QuantizationArgs] = None,
|
||||
input_quant: Optional[QuantizationArgs] = None,
|
||||
model_compression_config: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
self.quantized = quantized
|
||||
self.weight_quant = weight_quant
|
||||
self.input_quant = input_quant
|
||||
self.model_compressor = (
|
||||
ModelCompressor.from_compression_config(model_compression_config)
|
||||
if model_compression_config is not None else None)
|
||||
self.do_sparse_decompress = (
|
||||
self.model_compressor is not None
|
||||
and self.model_compressor.sparsity_config.format
|
||||
== CompressionFormat.sparse_24_bitmask.value)
|
||||
|
||||
if quantized and input_quant is not None and \
|
||||
self._get_quant_dtype() == current_platform.fp8_dtype():
|
||||
static = not input_quant.dynamic
|
||||
g_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
|
||||
self.quant_fp8 = QuantFP8(static, g_shape)
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# Only cutlass 3.x kernels are implemented so far
|
||||
return 90
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
if not sparse_cutlass_supported():
|
||||
raise ValueError(
|
||||
"Sparse CUTLASS not supported. vLLM must be built with "
|
||||
"CUDA 12.2 or later to use this feature")
|
||||
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size = input_size
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)
|
||||
|
||||
# parameter to store uncompressed weight
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=self.weights_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
if self.do_sparse_decompress:
|
||||
assert all(partition_size % 8 == 0
|
||||
for partition_size in output_partition_sizes
|
||||
), "All partitions must be divisible by 8 for "
|
||||
"2:4 sparse compressed models"
|
||||
|
||||
shape = BasevLLMParameter(
|
||||
data=torch.empty(2, 1, dtype=torch.int64),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
compressed_weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // 2,
|
||||
dtype=self.weights_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
bitmask = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // 8,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("shape", shape)
|
||||
layer.register_parameter("compressed", compressed_weight)
|
||||
layer.register_parameter("bitmask", bitmask)
|
||||
|
||||
# Check if quantized, not just 2:4 Sparse
|
||||
if self.quantized:
|
||||
if (self.weight_quant and self.weight_quant.strategy
|
||||
== QuantizationStrategy.CHANNEL.value):
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1),
|
||||
dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
else:
|
||||
assert (self.weight_quant and self.weight_quant.strategy
|
||||
== QuantizationStrategy.TENSOR.value)
|
||||
weight_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes),
|
||||
dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# input quant will be non-none
|
||||
if self.input_quant and not self.input_quant.dynamic:
|
||||
# register input quant scale
|
||||
assert (self.input_quant.strategy ==
|
||||
QuantizationStrategy.TENSOR.value)
|
||||
input_scale = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
else:
|
||||
# for sparse-only, pass in 1 for weight/input scales
|
||||
weight_scale = torch.nn.Parameter(data=torch.ones(
|
||||
1, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
input_scale = torch.nn.Parameter(data=torch.ones(
|
||||
1, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
"""
|
||||
Compress weights after loading. Store compressed weight and meta
|
||||
tensor
|
||||
|
||||
:post-condition: layer.w_compressed and layer.meta are
|
||||
set to the compressed weight and meta tensor in the
|
||||
format expected by the Cutlass kernels
|
||||
:param layer: The layer with the weights to be processed
|
||||
|
||||
"""
|
||||
if self.do_sparse_decompress:
|
||||
layer.weight.data = self._decompress_bitmask_compressed_weight(
|
||||
compressed=layer.compressed,
|
||||
bitmask=layer.bitmask,
|
||||
layer=layer,
|
||||
)
|
||||
|
||||
# compressed and bitmask tensors
|
||||
# are no longer needed after decompression
|
||||
del layer.compressed
|
||||
del layer.bitmask
|
||||
|
||||
# torch.compile workaround
|
||||
if hasattr(layer, "input_scale"):
|
||||
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
|
||||
requires_grad=False)
|
||||
|
||||
if self.weight_quant:
|
||||
if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
convert_to_channelwise(
|
||||
weight_scale=layer.weight_scale,
|
||||
logical_widths=layer.logical_widths,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
else:
|
||||
# torch.compile workaround
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
layer.weight_scale.data, requires_grad=False)
|
||||
|
||||
# Set all negative zero values to 0 prior to compression
|
||||
if (layer.weight.dtype.is_floating_point
|
||||
and layer.weight.dtype.itemsize >= 2):
|
||||
layer.weight.data[layer.weight.data == -0.0] = 0.0
|
||||
|
||||
w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data)
|
||||
layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
|
||||
layer.meta = torch.nn.Parameter(meta, requires_grad=False)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Returns the output tensor for the layer with 2:4
|
||||
sparse compressed weights, given the input tensor
|
||||
and bias
|
||||
|
||||
:param layer: The layer with 2:4 sparse compressed
|
||||
weights to be used for the computation
|
||||
:param x: The input tensor to the layer
|
||||
:param bias: The bias to be added to the output tensor
|
||||
:return: The output tensor of the layer
|
||||
"""
|
||||
if self.quantized:
|
||||
scale = getattr(layer, 'input_scale', None)
|
||||
|
||||
if self.weights_dtype == torch.int8:
|
||||
ops_output = ops.scaled_int8_quant(x, scale=scale)
|
||||
q_input = ops_output[0]
|
||||
input_scale = ops_output[1]
|
||||
else:
|
||||
assert self.weights_dtype == torch.float8_e4m3fn
|
||||
q_input, input_scale = self.quant_fp8(x, scale=scale)
|
||||
|
||||
else:
|
||||
# Not quantized, nothing to do with the input_scales, use as is
|
||||
input_scale = layer.input_scale
|
||||
q_input = x
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(
|
||||
a=q_input,
|
||||
bt_nzs=layer.weight,
|
||||
bt_meta=layer.meta,
|
||||
scale_a=input_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
out_dtype=x.dtype,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
assert out.is_contiguous()
|
||||
return out
|
||||
|
||||
def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
|
||||
if not self.quantized:
|
||||
return params_dtype
|
||||
return self._get_quant_dtype()
|
||||
|
||||
def _get_quant_dtype(self) -> torch.dtype:
|
||||
assert self.quantized
|
||||
assert self.weight_quant is not None
|
||||
assert self.input_quant is not None
|
||||
|
||||
is_8_bits = self.weight_quant.num_bits == self.input_quant.num_bits == 8
|
||||
|
||||
if not is_8_bits:
|
||||
raise ValueError("Cutlass only supports 8-bit quantization")
|
||||
|
||||
if (self.weight_quant.type == QuantizationType.FLOAT
|
||||
and self.input_quant.type == QuantizationType.FLOAT):
|
||||
return torch.float8_e4m3fn
|
||||
|
||||
if (self.weight_quant.type == QuantizationType.INT
|
||||
and self.input_quant.type == QuantizationType.INT):
|
||||
return torch.int8
|
||||
|
||||
raise ValueError("Quantization type not supported by Cutlass")
|
||||
|
||||
def _decompress_bitmask_compressed_weight(
|
||||
self,
|
||||
compressed: torch.Tensor,
|
||||
bitmask: torch.Tensor,
|
||||
layer: torch.nn.Module,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Decompress a compressed 2:4 sparse weight tensor using the bitmask and
|
||||
return the result.
|
||||
|
||||
This function also supports sharded decompression.
|
||||
|
||||
:param compressed: The 2:4 sparse weight tensor compressed using the
|
||||
sparse-24-bitmask compressor. This is different from
|
||||
`cutlass_sparse_compress` which uses a different scheme (2 bits for
|
||||
every nonzero element that represent the coordinate within the block
|
||||
of 4). The bitmask compression here uses a bitmask to indicate the
|
||||
positions of non-zero elements.
|
||||
:param bitmask: The 2:4 bitmask associated with the compressed weights,
|
||||
representing the positions of non-zero elements in the compressed
|
||||
tensor.
|
||||
:param layer: The layer whose weights need to be processed after
|
||||
loading.
|
||||
:return: The decompressed 2:4 sparse weight tensor.
|
||||
"""
|
||||
|
||||
sparsity_compressor = self.model_compressor.sparsity_compressor
|
||||
|
||||
def _process_split(
|
||||
bitmask_compressed_weight: torch.Tensor,
|
||||
shape,
|
||||
bitmask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
weight_data = dict(
|
||||
compressed=bitmask_compressed_weight,
|
||||
shape=shape,
|
||||
bitmask=bitmask,
|
||||
)
|
||||
return sparsity_compressor.decompress_weight(weight_data)
|
||||
|
||||
split_weights: list[torch.Tensor] = []
|
||||
split_bitmask: list[torch.Tensor] = []
|
||||
split_shape: list[tuple[int, int]] = []
|
||||
|
||||
if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
|
||||
split_weights = torch.split(compressed, layer.logical_widths)
|
||||
split_bitmask = torch.split(bitmask, layer.logical_widths)
|
||||
split_shape = [(out, layer.input_size_per_partition)
|
||||
for out in layer.logical_widths]
|
||||
|
||||
if split_weights:
|
||||
decompressed_shards = [
|
||||
_process_split(compressed_weight, shape, bitmask)
|
||||
for compressed_weight, shape, bitmask in zip(
|
||||
split_weights, split_shape, split_bitmask)
|
||||
]
|
||||
decompressed = combine_shards(decompressed_shards)
|
||||
else:
|
||||
decompressed = sparsity_compressor.decompress_weight(
|
||||
dict(
|
||||
compressed=compressed,
|
||||
shape=(
|
||||
layer.logical_widths[0],
|
||||
layer.input_size_per_partition,
|
||||
),
|
||||
bitmask=bitmask,
|
||||
))
|
||||
return decompressed
|
||||
@@ -0,0 +1,55 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ["CompressedTensorsScheme"]
|
||||
|
||||
|
||||
class CompressedTensorsScheme(ABC):
|
||||
"""
|
||||
Abstract class used to describe the weight creation and forward pass
|
||||
of different quantization schemes supported by CompressedTensors.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
"""
|
||||
Get minimum device capability.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(self, *args, **kwargs):
|
||||
"""
|
||||
Weight creation for the particular scheme. Inputs to this function
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]):
|
||||
"""
|
||||
Run the forward pass for the particular scheme. This is where
|
||||
scheme-specific dequant/quant steps/kernels should be applied.
|
||||
|
||||
:param layer: torch.nn.Module with the registered weights and
|
||||
other parameters relevant to the particular scheme.
|
||||
:param x: input to the layer
|
||||
:param bias: bias parameter
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
"""
|
||||
Called after weight loading is complete for any cleanup that
|
||||
needs to occur.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,160 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
__all__ = ["CompressedTensorsW4A16Sparse24"]
|
||||
W4A16SPARSE24_SUPPORTED_TYPES_MAP = {
|
||||
4: scalar_types.uint4b8,
|
||||
}
|
||||
W4A16SPARSE24_SUPPORTED_BITS = list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys())
|
||||
|
||||
|
||||
class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
|
||||
|
||||
def __init__(self,
|
||||
strategy: str,
|
||||
num_bits: int,
|
||||
group_size: Optional[int] = None):
|
||||
self.strategy = strategy
|
||||
self.group_size = group_size
|
||||
self.tile_size = 16
|
||||
|
||||
if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP:
|
||||
raise ValueError(
|
||||
f"Unsupported num_bits = {num_bits}. "
|
||||
f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}")
|
||||
|
||||
self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits]
|
||||
|
||||
if self.strategy == "group" and self.group_size is None:
|
||||
raise ValueError(
|
||||
"group_size must be given when using strategy group")
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# ampere + up
|
||||
return 80
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.weight_packed = Parameter(layer.weight_packed.data,
|
||||
requires_grad=False)
|
||||
layer.scale_packed = Parameter(layer.scale_packed.data,
|
||||
requires_grad=False)
|
||||
layer.meta = Parameter(layer.meta.data, requires_grad=False)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, input_size: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
|
||||
assert params_dtype == torch.float16, (
|
||||
"float16 is required for marlin24 compressed models. Set dtype=torch.float16" # noqa: E501
|
||||
)
|
||||
|
||||
pack_factor = 32 // self.quant_type.size_bits
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
qweight = PackedvLLMParameter(data=torch.empty(
|
||||
input_size_per_partition // self.tile_size // 2,
|
||||
output_size_per_partition * self.tile_size // pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=pack_factor,
|
||||
marlin_tile_size=self.tile_size,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
input_groups = (1 if self.group_size is None else
|
||||
input_size_per_partition // self.group_size)
|
||||
|
||||
weight_scale_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
input_groups,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
|
||||
if self.group_size is not None:
|
||||
scales = GroupQuantScaleParameter(output_dim=1,
|
||||
input_dim=0,
|
||||
**weight_scale_args)
|
||||
else:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1,
|
||||
**weight_scale_args)
|
||||
|
||||
weight_shape = BasevLLMParameter(data=torch.empty(2,
|
||||
dtype=torch.int64),
|
||||
weight_loader=weight_loader)
|
||||
|
||||
meta = PackedvLLMParameter(data=torch.empty(
|
||||
input_size_per_partition // 8 // 2 // 2,
|
||||
output_size_per_partition * 2,
|
||||
dtype=torch.int16,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=1,
|
||||
marlin_tile_size=2,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("weight_packed", qweight)
|
||||
layer.register_parameter("weight_shape", weight_shape)
|
||||
layer.register_parameter("scale_packed", scales)
|
||||
layer.register_parameter("meta", meta)
|
||||
|
||||
max_workspace_size = (
|
||||
output_size_per_partition //
|
||||
GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL
|
||||
|
||||
workspace = Parameter(torch.zeros(max_workspace_size, dtype=torch.int),
|
||||
requires_grad=False)
|
||||
layer.workspace = workspace
|
||||
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
|
||||
qweight = layer.weight_packed
|
||||
meta = layer.meta
|
||||
scales = layer.scale_packed
|
||||
workspace = layer.workspace
|
||||
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
size_m = x_2d.shape[0]
|
||||
size_k = x_2d.shape[1]
|
||||
size_n = scales.shape[1]
|
||||
|
||||
output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
|
||||
workspace, self.quant_type, size_m,
|
||||
size_n, size_k)
|
||||
|
||||
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output
|
||||
@@ -0,0 +1,105 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
apply_fp4_marlin_linear, prepare_fp4_layer_for_marlin)
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
|
||||
__all__ = ["CompressedTensorsW4A16Fp4"]
|
||||
|
||||
|
||||
class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
|
||||
|
||||
def __init__(self, has_input_global_scale: bool = False):
|
||||
self.has_input_global_scale = has_input_global_scale
|
||||
self.group_size = 16
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# dont restrict as emulations
|
||||
return 80
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
|
||||
# Weight
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // 2,
|
||||
dtype=torch.uint8),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight_packed", weight)
|
||||
|
||||
# Global Weight Scale
|
||||
weight_global_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight_global_scale", weight_global_scale)
|
||||
|
||||
# Per Group Weight Scale
|
||||
weight_scale = GroupQuantScaleParameter(data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // self.group_size,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
if self.has_input_global_scale:
|
||||
input_global_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes),
|
||||
dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("input_global_scale", input_global_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
# Process parameters for marlin repacking
|
||||
|
||||
# Rename weight_packed to weight that marlin expects
|
||||
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
|
||||
del layer.weight_packed
|
||||
# Rename weight_global_scale to weight_scale_2 that marlin expects
|
||||
# Note: ct stores the inverse of what is expected by the marlin kernel
|
||||
layer.weight_scale_2 = Parameter(
|
||||
1 / layer.weight_global_scale.max().to(torch.float32),
|
||||
requires_grad=False)
|
||||
del layer.weight_global_scale
|
||||
|
||||
if self.has_input_global_scale:
|
||||
layer.input_global_scale = torch.nn.Parameter(
|
||||
layer.input_global_scale.data, requires_grad=False)
|
||||
|
||||
prepare_fp4_layer_for_marlin(layer)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return apply_fp4_marlin_linear(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_scale_2=layer.weight_scale_2,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias)
|
||||
@@ -0,0 +1,185 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
|
||||
run_nvfp4_emulations)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
swizzle_blockscale)
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["CompressedTensorsW4A4Fp4"]
|
||||
|
||||
|
||||
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
|
||||
def __init__(self):
|
||||
if envs.VLLM_USE_TRTLLM_FP4_GEMM:
|
||||
assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
|
||||
self.backend = "flashinfer-trtllm"
|
||||
logger.info_once("Using flashinfer-trtllm for FP4")
|
||||
elif envs.VLLM_USE_FBGEMM:
|
||||
self.backend = "fbgemm"
|
||||
try:
|
||||
import fbgemm_gpu # noqa: F401
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Backend fbgemm requires fbgemm.f4f4bf16 operator, "
|
||||
"Please install with: pip install fbgemm-gpu-genai"
|
||||
) from exc
|
||||
logger.info_once("Using FGBEMM-GPU-GENAI for FP4")
|
||||
elif has_flashinfer():
|
||||
self.backend = "flashinfer-cutlass"
|
||||
logger.info_once("Using flashinfer-cutlass for FP4")
|
||||
else:
|
||||
self.backend = "cutlass"
|
||||
self.group_size = 16
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
if envs.VLLM_USE_NVFP4_CT_EMULATIONS:
|
||||
return 80
|
||||
return 100
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
|
||||
# Weight
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // 2,
|
||||
dtype=torch.uint8),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight_packed", weight)
|
||||
|
||||
# Global Weight Scale
|
||||
weight_global_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight_global_scale", weight_global_scale)
|
||||
|
||||
# Per Group Weight Scale
|
||||
weight_scale = GroupQuantScaleParameter(data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // self.group_size,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
input_global_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("input_global_scale", input_global_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
|
||||
global_input_scale = layer.input_global_scale.max().to(torch.float32)
|
||||
layer.input_global_scale = Parameter(global_input_scale,
|
||||
requires_grad=False)
|
||||
|
||||
layer.weight_global_scale = Parameter(
|
||||
layer.weight_global_scale.max().to(torch.float32),
|
||||
requires_grad=False)
|
||||
|
||||
if self.backend == "flashinfer-trtllm":
|
||||
# FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
|
||||
# FlashInfer provides nvfp4_quantize to quantize + shuffle the
|
||||
# layout but we use our own quantization so we have to call
|
||||
# shuffles ourselves.
|
||||
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
|
||||
|
||||
weight = layer.weight_packed.data
|
||||
weight_scale = layer.weight_scale.data
|
||||
|
||||
epilogue_tile_m = 128
|
||||
weight = shuffle_matrix_a(weight.view(torch.uint8),
|
||||
epilogue_tile_m)
|
||||
weight_scale = (shuffle_matrix_sf_a(weight_scale.view(
|
||||
torch.uint8), epilogue_tile_m).reshape(
|
||||
weight_scale.shape).view(torch.float8_e4m3fn))
|
||||
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
layer.weight_packed = Parameter(weight, requires_grad=False)
|
||||
else:
|
||||
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
|
||||
if self.backend == "fbgemm":
|
||||
swizzled_weight_scale = swizzled_weight_scale.view(-1).view(
|
||||
torch.uint8)
|
||||
layer.weight_scale = Parameter(swizzled_weight_scale,
|
||||
requires_grad=False)
|
||||
layer.weight_packed = Parameter(layer.weight_packed.data,
|
||||
requires_grad=False)
|
||||
|
||||
layer.alpha = Parameter(
|
||||
1 / (layer.input_global_scale * layer.weight_global_scale),
|
||||
requires_grad=False)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
if envs.VLLM_USE_NVFP4_CT_EMULATIONS:
|
||||
out = run_nvfp4_emulations(
|
||||
x=x,
|
||||
input_global_scale=layer.input_global_scale,
|
||||
weight=layer.weight_packed,
|
||||
weight_scale_swizzled=layer.weight_scale,
|
||||
weight_global_scale=layer.weight_global_scale)
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out
|
||||
|
||||
output_dtype = x.dtype
|
||||
output_shape = [x.shape[0], layer.weight_packed.shape[0]]
|
||||
|
||||
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
|
||||
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
|
||||
|
||||
mm_args = (x_fp4, layer.weight_packed, x_blockscale,
|
||||
layer.weight_scale, layer.alpha, output_dtype)
|
||||
if self.backend == "flashinfer-trtllm":
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
|
||||
elif self.backend == "flashinfer-cutlass":
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
|
||||
elif self.backend == "fbgemm":
|
||||
out = torch.ops.fbgemm.f4f4bf16(
|
||||
x_fp4,
|
||||
layer.weight_packed,
|
||||
x_blockscale.view(-1).view(torch.uint8),
|
||||
layer.weight_scale,
|
||||
layer.alpha,
|
||||
use_mx=False,
|
||||
).to(output_dtype)
|
||||
else:
|
||||
out = cutlass_scaled_fp4_mm(*mm_args)
|
||||
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out.view(*output_shape)
|
||||
@@ -0,0 +1,169 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import ActivationOrdering
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
MPLinearLayerConfig, choose_mp_linear_kernel)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_repeat_scales_on_all_ranks)
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
# yapf: enable
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["CompressedTensorsW4A8Fp8"]
|
||||
W4A8_SUPPORTED_TYPES_MAP = {
|
||||
4: scalar_types.int4,
|
||||
}
|
||||
W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys())
|
||||
|
||||
|
||||
class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self,
|
||||
strategy: str,
|
||||
num_bits: int,
|
||||
group_size: Optional[int] = None,
|
||||
symmetric: Optional[bool] = True,
|
||||
actorder: Optional[ActivationOrdering] = None):
|
||||
|
||||
self.pack_factor = 32 // num_bits
|
||||
self.strategy = strategy
|
||||
self.symmetric = symmetric
|
||||
self.group_size = -1 if group_size is None else group_size
|
||||
self.has_g_idx = actorder == ActivationOrdering.GROUP
|
||||
|
||||
if self.group_size != 128 or self.strategy != "group":
|
||||
raise ValueError("W4A8 kernels require group quantization " \
|
||||
"with group size 128")
|
||||
|
||||
if num_bits not in W4A8_SUPPORTED_TYPES_MAP:
|
||||
raise ValueError(
|
||||
f"Unsupported num_bits = {num_bits}. "
|
||||
f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}")
|
||||
|
||||
self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# hopper
|
||||
return 90
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, output_size: int,
|
||||
input_size: int, output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
mp_linear_kernel_config = MPLinearLayerConfig(
|
||||
full_weight_shape=(input_size, output_size),
|
||||
partition_weight_shape=\
|
||||
(input_size_per_partition, output_size_per_partition),
|
||||
weight_type=self.quant_type,
|
||||
act_type=torch.float8_e4m3fn, # always use fp8(e4m3)
|
||||
group_size=self.group_size,
|
||||
zero_points=not self.symmetric,
|
||||
has_g_idx=self.has_g_idx,
|
||||
out_type=params_dtype
|
||||
)
|
||||
|
||||
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
|
||||
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for CompressedTensorsW4A8Fp8",
|
||||
kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
# If group_size is -1, we are in channelwise case.
|
||||
group_size = self.group_size if self.group_size != -1 else input_size
|
||||
row_parallel = (input_size != input_size_per_partition)
|
||||
partition_scales = not marlin_repeat_scales_on_all_ranks(
|
||||
self.has_g_idx, self.group_size, row_parallel)
|
||||
|
||||
scales_and_zp_size = input_size // group_size
|
||||
|
||||
if partition_scales:
|
||||
assert input_size_per_partition % group_size == 0
|
||||
scales_and_zp_size = input_size_per_partition // group_size
|
||||
|
||||
weight = PackedvLLMParameter(input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
packed_factor=self.pack_factor,
|
||||
packed_dim=1,
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition //
|
||||
self.pack_factor,
|
||||
dtype=torch.int32,
|
||||
))
|
||||
|
||||
# TODO(czhu): allocate the packed fp8 scales memory here?
|
||||
# the scales will be expanded by 8x via `cutlass_pack_scale_fp8`
|
||||
weight_scale_args = {
|
||||
"weight_loader":
|
||||
weight_loader,
|
||||
"data":
|
||||
torch.empty(
|
||||
output_size_per_partition,
|
||||
scales_and_zp_size,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
)
|
||||
}
|
||||
|
||||
if not partition_scales:
|
||||
weight_scale = ChannelQuantScaleParameter(output_dim=0,
|
||||
**weight_scale_args)
|
||||
else:
|
||||
weight_scale = GroupQuantScaleParameter(output_dim=0,
|
||||
input_dim=1,
|
||||
**weight_scale_args)
|
||||
|
||||
# A 2D array defining the original shape of the weights
|
||||
# before packing
|
||||
weight_shape = BasevLLMParameter(data=torch.empty(2,
|
||||
dtype=torch.int64),
|
||||
weight_loader=weight_loader)
|
||||
|
||||
# per-channel scales
|
||||
weight_chan_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((output_size_per_partition, 1),
|
||||
dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("weight_packed", weight)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
layer.register_parameter("weight_shape", weight_shape)
|
||||
layer.register_parameter("weight_chan_scale", weight_chan_scale)
|
||||
|
||||
self.kernel = kernel_type(mp_linear_kernel_config,
|
||||
w_q_param_name="weight_packed",
|
||||
w_s_param_name="weight_scale",
|
||||
w_zp_param_name="weight_zero_point",
|
||||
w_gidx_param_name="weight_g_idx")
|
||||
|
||||
# Checkpoints are serialized in compressed-tensors format, which is
|
||||
# different from the format the kernel may want. Handle repacking here.
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
@@ -0,0 +1,135 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
MPLinearLayerConfig, choose_mp_linear_kernel)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["CompressedTensorsW4A8Int"]
|
||||
W4A8_SUPPORTED_TYPES_MAP = {
|
||||
4: scalar_types.int4,
|
||||
}
|
||||
W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys())
|
||||
|
||||
|
||||
class CompressedTensorsW4A8Int(CompressedTensorsScheme):
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self,
|
||||
strategy: str,
|
||||
num_bits: int,
|
||||
group_size: Optional[int] = None,
|
||||
is_static_input_scheme: bool = False,
|
||||
input_symmetric: bool = True):
|
||||
self.strategy = strategy
|
||||
self.group_size = -1 if group_size is None else group_size
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.input_symmetric = input_symmetric
|
||||
|
||||
if num_bits not in W4A8_SUPPORTED_TYPES_MAP:
|
||||
raise ValueError(
|
||||
f"Unsupported num_bits = {num_bits}."
|
||||
f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}")
|
||||
self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 1
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, output_size: int,
|
||||
input_size: int, output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
row_parallel = (input_size != input_size_per_partition)
|
||||
|
||||
# Compute effective group_size
|
||||
if self.group_size == -1:
|
||||
effective_group_size = (input_size_per_partition
|
||||
if row_parallel else input_size)
|
||||
else:
|
||||
effective_group_size = self.group_size
|
||||
|
||||
# Ensure group_size divides input_size_per_partition
|
||||
assert input_size_per_partition % effective_group_size == 0, (
|
||||
f"input_size_per_partition {input_size_per_partition}"
|
||||
f" not divisible by group_size {effective_group_size}")
|
||||
|
||||
# Determine scale partitioning
|
||||
is_channelwise = (self.group_size == -1)
|
||||
repeat_scales = (is_channelwise and row_parallel)
|
||||
partition_scales = not repeat_scales
|
||||
|
||||
mp_linear_kernel_config = MPLinearLayerConfig(
|
||||
full_weight_shape=(input_size, output_size),
|
||||
partition_weight_shape=(input_size_per_partition,
|
||||
output_size_per_partition),
|
||||
weight_type=self.quant_type,
|
||||
act_type=params_dtype,
|
||||
group_size=effective_group_size,
|
||||
zero_points=False,
|
||||
has_g_idx=False,
|
||||
)
|
||||
|
||||
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for CompressedTensorsW4A8Int",
|
||||
kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
scales_and_zp_size = input_size_per_partition // effective_group_size
|
||||
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=torch.int8),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
weight_scale_args = {
|
||||
"weight_loader":
|
||||
weight_loader,
|
||||
"data":
|
||||
torch.empty(output_size_per_partition,
|
||||
scales_and_zp_size,
|
||||
dtype=params_dtype)
|
||||
}
|
||||
|
||||
if partition_scales:
|
||||
weight_scale = GroupQuantScaleParameter(output_dim=0,
|
||||
input_dim=1,
|
||||
**weight_scale_args)
|
||||
else:
|
||||
weight_scale = ChannelQuantScaleParameter(output_dim=0,
|
||||
**weight_scale_args)
|
||||
|
||||
layer.register_parameter("weight_packed", weight)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
self.kernel = kernel_type(mp_linear_kernel_config,
|
||||
w_q_param_name="weight_packed",
|
||||
w_s_param_name="weight_scale",
|
||||
w_zp_param_name=None,
|
||||
w_gidx_param_name=None)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
@@ -0,0 +1,121 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
|
||||
__all__ = ["CompressedTensorsW8A16Fp8"]
|
||||
|
||||
SUPPORTED_STRATEGIES = [
|
||||
QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR
|
||||
]
|
||||
|
||||
|
||||
class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
|
||||
|
||||
def __init__(self, strategy: str, is_static_input_scheme: bool):
|
||||
self.strategy = strategy
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# ampere and up
|
||||
return 80
|
||||
|
||||
# W8A8-Fp8 kernels support only per-tensor and per-channel cases.
|
||||
# So if we have a fused module (QKV, MLP) with per tensor scales,
|
||||
# we expand each scale to its shard's channels.
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
if self.strategy == QuantizationStrategy.TENSOR:
|
||||
ws_channelwise = convert_to_channelwise(layer.weight_scale,
|
||||
layer.logical_widths)
|
||||
layer.weight_scale = torch.nn.Parameter(ws_channelwise,
|
||||
requires_grad=False)
|
||||
else:
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
|
||||
requires_grad=False)
|
||||
|
||||
# Weights must be transposed for marlin
|
||||
layer.weight = torch.nn.Parameter(layer.weight.t(),
|
||||
requires_grad=False)
|
||||
|
||||
if self.is_static_input_scheme:
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
|
||||
requires_grad=False)
|
||||
prepare_fp8_layer_for_marlin(layer)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, input_size: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=torch.float8_e4m3fn),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1),
|
||||
dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
elif self.strategy == QuantizationStrategy.TENSOR:
|
||||
weight_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported weight strategy={self.strategy}, "
|
||||
f"supported strategies are {SUPPORTED_STRATEGIES}")
|
||||
|
||||
weight_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE (to deal with converted checkpoints)
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
return apply_fp8_marlin_linear(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias)
|
||||
@@ -0,0 +1,157 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import (QuantizationArgs,
|
||||
QuantizationStrategy)
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
apply_fp8_block_linear, check_aiter_fp8_linear_support,
|
||||
create_fp8_input_scale, create_fp8_scale_parameter,
|
||||
create_fp8_weight_parameter, maybe_post_process_fp8_weight_block,
|
||||
process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy,
|
||||
process_fp8_weight_tensor_strategy, validate_fp8_block_shape)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, cutlass_block_fp8_supported, maybe_create_device_identity)
|
||||
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
PerTensorScaleParameter)
|
||||
|
||||
__all__ = ["CompressedTensorsW8A8Fp8"]
|
||||
|
||||
strategy_to_parameter_type = {
|
||||
QuantizationStrategy.BLOCK: BlockQuantScaleParameter,
|
||||
QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter,
|
||||
QuantizationStrategy.TENSOR: PerTensorScaleParameter,
|
||||
}
|
||||
|
||||
|
||||
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
|
||||
def __init__(self, weight_quant: QuantizationArgs,
|
||||
is_static_input_scheme: bool):
|
||||
self.weight_quant = weight_quant
|
||||
self.strategy = weight_quant.strategy
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.act_q_group_shape = GroupShape.PER_TENSOR \
|
||||
if is_static_input_scheme else GroupShape.PER_TOKEN
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=self.is_static_input_scheme,
|
||||
act_quant_group_shape=self.act_q_group_shape)
|
||||
|
||||
self.weight_block_size = self.weight_quant.block_structure
|
||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# lovelace and up
|
||||
return 89
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
weight_loader: Callable, **kwargs):
|
||||
maybe_create_device_identity()
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.weight_block_size = None
|
||||
|
||||
if self.strategy == QuantizationStrategy.BLOCK:
|
||||
assert self.weight_block_size is not None
|
||||
layer.weight_block_size = self.weight_block_size
|
||||
# Validate block quantization shapes
|
||||
validate_fp8_block_shape(layer, input_size, output_size,
|
||||
input_size_per_partition,
|
||||
output_partition_sizes,
|
||||
self.weight_block_size)
|
||||
|
||||
# WEIGHT
|
||||
weight = create_fp8_weight_parameter(output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
weight_loader)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
weight_scale = create_fp8_scale_parameter(
|
||||
strategy_to_parameter_type[self.strategy], output_partition_sizes,
|
||||
input_size_per_partition, layer.weight_block_size, weight_loader)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = create_fp8_input_scale(output_partition_sizes,
|
||||
weight_loader)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
if self.strategy == QuantizationStrategy.TENSOR:
|
||||
weight, weight_scale, input_scale = (
|
||||
process_fp8_weight_tensor_strategy(
|
||||
layer.weight, layer.weight_scale, layer.logical_widths,
|
||||
getattr(layer, 'input_scale', None)))
|
||||
weight = weight.t()
|
||||
|
||||
elif self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight, weight_scale, input_scale = (
|
||||
process_fp8_weight_channel_strategy(
|
||||
layer.weight, layer.weight_scale,
|
||||
getattr(layer, 'input_scale', None)))
|
||||
weight = weight.t()
|
||||
|
||||
elif self.strategy == QuantizationStrategy.BLOCK:
|
||||
assert self.is_static_input_scheme is False
|
||||
weight, weight_scale = process_fp8_weight_block_strategy(
|
||||
layer.weight, layer.weight_scale)
|
||||
input_scale = None
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown quantization strategy {self.strategy}")
|
||||
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.weight = Parameter(weight.data, requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
|
||||
if input_scale is not None:
|
||||
layer.input_scale = Parameter(input_scale.data,
|
||||
requires_grad=False)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme and hasattr(layer, 'input_scale'):
|
||||
layer.input_scale = Parameter(layer.input_scale.max(),
|
||||
requires_grad=False)
|
||||
else:
|
||||
layer.input_scale = None
|
||||
|
||||
if self.strategy == QuantizationStrategy.BLOCK:
|
||||
maybe_post_process_fp8_weight_block(
|
||||
layer, self.cutlass_block_fp8_supported)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
if layer.weight_block_size is not None:
|
||||
return apply_fp8_block_linear(
|
||||
layer,
|
||||
input=x,
|
||||
bias=bias,
|
||||
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported)
|
||||
|
||||
return self.fp8_linear.apply(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias)
|
||||
@@ -0,0 +1,111 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self, strategy: str, is_static_input_scheme: bool,
|
||||
input_symmetric: bool):
|
||||
self.strategy = strategy
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.input_symmetric = input_symmetric
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# turing and up
|
||||
return 75
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
|
||||
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
|
||||
is_static_input_scheme=self.is_static_input_scheme,
|
||||
input_symmetric=self.input_symmetric)
|
||||
|
||||
kernel_type = choose_scaled_mm_linear_kernel(
|
||||
scaled_mm_linear_kernel_config)
|
||||
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for CompressedTensorsW8A8Int8",
|
||||
kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=torch.int8),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1),
|
||||
dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
else:
|
||||
assert self.strategy == QuantizationStrategy.TENSOR
|
||||
weight_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = BasevLLMParameter(data=torch.empty(
|
||||
1, dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
if not self.input_symmetric:
|
||||
# Note: compressed-tensors stores the zp using the same dtype
|
||||
# as the weights
|
||||
# AZP loaded as int8 but used as int32
|
||||
input_zero_point = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.int8),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
|
||||
self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
|
||||
w_q_param_name="weight",
|
||||
w_s_param_name="weight_scale",
|
||||
i_s_param_name="input_scale",
|
||||
i_zp_param_name="input_zero_point",
|
||||
azp_adj_param_name="azp_adj")
|
||||
|
||||
# Checkpoints are serialized in compressed-tensors format, which is
|
||||
# different from the format the kernel may want. Handle repacking here.
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
@@ -0,0 +1,201 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import ActivationOrdering
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
MPLinearLayerConfig, choose_mp_linear_kernel)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_repeat_scales_on_all_ranks)
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter)
|
||||
# yapf: enable
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["CompressedTensorsWNA16"]
|
||||
WNA16_SUPPORTED_TYPES_MAP = {
|
||||
4: scalar_types.uint4b8,
|
||||
8: scalar_types.uint8b128
|
||||
}
|
||||
WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}
|
||||
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
|
||||
|
||||
|
||||
class CompressedTensorsWNA16(CompressedTensorsScheme):
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self,
|
||||
strategy: str,
|
||||
num_bits: int,
|
||||
group_size: Optional[int] = None,
|
||||
symmetric: Optional[bool] = True,
|
||||
actorder: Optional[ActivationOrdering] = None):
|
||||
|
||||
self.pack_factor = 32 // num_bits
|
||||
self.strategy = strategy
|
||||
self.symmetric = symmetric
|
||||
self.group_size = -1 if group_size is None else group_size
|
||||
self.has_g_idx = actorder == ActivationOrdering.GROUP
|
||||
|
||||
if self.group_size == -1 and self.strategy != "channel":
|
||||
raise ValueError("Marlin kernels require group quantization or "
|
||||
"channelwise quantization, but found no group "
|
||||
"size and strategy is not channelwise.")
|
||||
|
||||
if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
|
||||
raise ValueError(
|
||||
f"Unsupported num_bits = {num_bits}. "
|
||||
f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")
|
||||
|
||||
self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
|
||||
if not self.symmetric else
|
||||
WNA16_SUPPORTED_TYPES_MAP[num_bits])
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# ampere and up
|
||||
return 80
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, output_size: int,
|
||||
input_size: int, output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
mp_linear_kernel_config = MPLinearLayerConfig(
|
||||
full_weight_shape=(input_size, output_size),
|
||||
partition_weight_shape=\
|
||||
(input_size_per_partition, output_size_per_partition),
|
||||
weight_type=self.quant_type,
|
||||
act_type=params_dtype,
|
||||
group_size=self.group_size,
|
||||
zero_points=not self.symmetric,
|
||||
has_g_idx=self.has_g_idx
|
||||
)
|
||||
|
||||
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
|
||||
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for CompressedTensorsWNA16",
|
||||
kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
# If group_size is -1, we are in channelwise case.
|
||||
group_size = self.group_size if self.group_size != -1 else input_size
|
||||
row_parallel = (input_size != input_size_per_partition)
|
||||
partition_scales = not marlin_repeat_scales_on_all_ranks(
|
||||
self.has_g_idx, self.group_size, row_parallel)
|
||||
|
||||
scales_and_zp_size = input_size // group_size
|
||||
|
||||
if partition_scales:
|
||||
assert input_size_per_partition % group_size == 0
|
||||
scales_and_zp_size = input_size_per_partition // group_size
|
||||
|
||||
weight = PackedvLLMParameter(input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
packed_factor=self.pack_factor,
|
||||
packed_dim=1,
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition //
|
||||
self.pack_factor,
|
||||
dtype=torch.int32,
|
||||
))
|
||||
|
||||
weight_scale_args = {
|
||||
"weight_loader":
|
||||
weight_loader,
|
||||
"data":
|
||||
torch.empty(
|
||||
output_size_per_partition,
|
||||
scales_and_zp_size,
|
||||
dtype=params_dtype,
|
||||
)
|
||||
}
|
||||
|
||||
zeros_args = {
|
||||
"weight_loader":
|
||||
weight_loader,
|
||||
"data":
|
||||
torch.zeros(
|
||||
output_size_per_partition // self.pack_factor,
|
||||
scales_and_zp_size,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
}
|
||||
|
||||
if not partition_scales:
|
||||
weight_scale = ChannelQuantScaleParameter(output_dim=0,
|
||||
**weight_scale_args)
|
||||
|
||||
if not self.symmetric:
|
||||
qzeros = PackedColumnParameter(output_dim=0,
|
||||
packed_dim=0,
|
||||
packed_factor=self.pack_factor,
|
||||
**zeros_args)
|
||||
else:
|
||||
weight_scale = GroupQuantScaleParameter(output_dim=0,
|
||||
input_dim=1,
|
||||
**weight_scale_args)
|
||||
if not self.symmetric:
|
||||
qzeros = PackedvLLMParameter(input_dim=1,
|
||||
output_dim=0,
|
||||
packed_dim=0,
|
||||
packed_factor=self.pack_factor,
|
||||
**zeros_args)
|
||||
|
||||
# A 2D array defining the original shape of the weights
|
||||
# before packing
|
||||
weight_shape = BasevLLMParameter(data=torch.empty(2,
|
||||
dtype=torch.int64),
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("weight_packed", weight)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
layer.register_parameter("weight_shape", weight_shape)
|
||||
|
||||
if not self.symmetric:
|
||||
layer.register_parameter("weight_zero_point", qzeros)
|
||||
|
||||
# group index (for activation reordering)
|
||||
if self.has_g_idx:
|
||||
weight_g_idx = RowvLLMParameter(data=torch.empty(
|
||||
input_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight_g_idx", weight_g_idx)
|
||||
|
||||
self.kernel = kernel_type(mp_linear_kernel_config,
|
||||
w_q_param_name="weight_packed",
|
||||
w_s_param_name="weight_scale",
|
||||
w_zp_param_name="weight_zero_point",
|
||||
w_gidx_param_name="weight_g_idx")
|
||||
|
||||
# Checkpoints are serialized in compressed-tensors format, which is
|
||||
# different from the format the kernel may want. Handle repacking here.
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
@@ -0,0 +1,238 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Generator
|
||||
from itertools import accumulate
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors.transform import (TransformArgs, TransformConfig,
|
||||
TransformLocation, TransformScheme)
|
||||
from compressed_tensors.utils import is_match
|
||||
|
||||
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
|
||||
LinearMethodBase,
|
||||
QKVCrossParallelLinear)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.transform.module import ( # noqa: E501
|
||||
HadamardTransform)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501
|
||||
TransformTuple)
|
||||
|
||||
|
||||
class CompressedTensorsLinearTransformMethod(LinearMethodBase):
|
||||
"""
|
||||
Wraps `CompressedTensorsLinearMethod` or `UnquantizedLinearMethod` and adds
|
||||
input and output transforms to either side of the original apply method
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_schemes(
|
||||
cls,
|
||||
quant_method: LinearMethodBase,
|
||||
quant_scheme: Optional[CompressedTensorsScheme],
|
||||
input_tfms: dict[int, TransformTuple],
|
||||
output_tfms: dict[int, TransformTuple],
|
||||
) -> "CompressedTensorsLinearTransformMethod":
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.transform.schemes.linear_qutlass_nvfp4 import ( # noqa: E501
|
||||
QutlassNvFP4LinearMethod, is_qutlass_fp4_scheme)
|
||||
|
||||
assert input_tfms or output_tfms
|
||||
|
||||
if is_qutlass_fp4_scheme(quant_scheme, input_tfms):
|
||||
return QutlassNvFP4LinearMethod(quant_method, input_tfms,
|
||||
output_tfms)
|
||||
|
||||
# hadacore or dense gemm is selected by Transform module
|
||||
|
||||
return cls(quant_method, input_tfms, output_tfms)
|
||||
|
||||
def __init__(self, quant_method: LinearMethodBase,
|
||||
input_tfms: dict[int, TransformTuple],
|
||||
output_tfms: dict[int, TransformTuple]):
|
||||
self.quant_method = quant_method
|
||||
self.input_tfms = input_tfms
|
||||
self.output_tfms = output_tfms
|
||||
|
||||
self.input_transform: Optional[HadamardTransform] = None
|
||||
self.output_transform: Optional[HadamardTransform] = None
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
|
||||
# get weight loader for transforms
|
||||
weight_loader: Callable = extra_weight_attrs.get(
|
||||
"weight_loader") # type: ignore[assignment]
|
||||
|
||||
# HACK: UnquantizedLinearMethod does not support weight loader v2, but
|
||||
# transforms (specifically SharedWeightParameter) requires
|
||||
# weight loader v2. Until UnquantizedLinearMethod supports v2, we must
|
||||
# hack around this by getting weight loader v1 so ULM can load correctly
|
||||
quant_method_name = self.quant_method.__class__.__name__
|
||||
if quant_method_name not in WEIGHT_LOADER_V2_SUPPORTED:
|
||||
if isinstance(layer, QKVCrossParallelLinear):
|
||||
weight_loader_v1 = layer.weight_loader_v1
|
||||
else:
|
||||
weight_loader_v1 = layer.weight_loader
|
||||
extra_weight_attrs["weight_loader"] = weight_loader_v1
|
||||
|
||||
self.quant_method.create_weights(
|
||||
layer=layer,
|
||||
input_size_per_partition=input_size_per_partition,
|
||||
output_partition_sizes=output_partition_sizes,
|
||||
input_size=input_size,
|
||||
output_size=output_size,
|
||||
params_dtype=params_dtype,
|
||||
**extra_weight_attrs)
|
||||
|
||||
# validate schemes
|
||||
num_partitions = len(output_partition_sizes)
|
||||
self._validate_tfm_schemes(num_partitions)
|
||||
|
||||
# create submodules for weight loading
|
||||
if len(self.input_tfms) > 0:
|
||||
scheme_name = list(self.input_tfms.values())[0].scheme_name
|
||||
location = list(self.input_tfms.values())[0].args.location
|
||||
transform_name = f"{scheme_name}_{location}"
|
||||
|
||||
transform = HadamardTransform(self.input_tfms, layer,
|
||||
weight_loader,
|
||||
input_size_per_partition,
|
||||
output_partition_sizes)
|
||||
layer.register_module(transform_name, transform)
|
||||
self.input_transform = transform
|
||||
|
||||
if len(self.output_tfms) > 0:
|
||||
scheme_name = list(self.output_tfms.values())[0].scheme_name
|
||||
location = list(self.output_tfms.values())[0].args.location
|
||||
transform_name = f"{scheme_name}_{location}"
|
||||
|
||||
transform = HadamardTransform(self.output_tfms, layer,
|
||||
weight_loader,
|
||||
input_size_per_partition,
|
||||
output_partition_sizes)
|
||||
layer.register_module(transform_name, transform)
|
||||
self.output_transform = transform
|
||||
|
||||
# compute partition ranges for slicing activations
|
||||
starts = [0] + list(accumulate(output_partition_sizes))[:-1]
|
||||
self.partition_ranges = list(zip(starts, output_partition_sizes))
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
self.quant_method.process_weights_after_loading(layer)
|
||||
|
||||
for submodule in layer.children():
|
||||
if isinstance(submodule, HadamardTransform):
|
||||
submodule.process_weights_after_loading()
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
if self.input_transform is not None:
|
||||
x = self.input_transform(x)
|
||||
|
||||
assert bias is None
|
||||
x = self.quant_method.apply(layer, x, bias)
|
||||
|
||||
# In most cases, input transforms are preferred over output transforms
|
||||
# (@ksayers): confirm that this is done concurrently
|
||||
if self.output_transform is not None:
|
||||
for part_id, (start, length) in enumerate(self.partition_ranges):
|
||||
x[:, start:start + length] = self.output_transform(
|
||||
x[:, start:start + length].contiguous(), part_id=part_id)
|
||||
|
||||
return x
|
||||
|
||||
def _validate_tfm_schemes(self, num_partitions: int):
|
||||
if len(self.input_tfms) > 0:
|
||||
if 0 not in self.input_tfms:
|
||||
raise ValueError("Must have same input")
|
||||
|
||||
for part_index in range(num_partitions):
|
||||
if self.input_tfms[part_index] != self.input_tfms[0]:
|
||||
raise ValueError("Must have same input")
|
||||
|
||||
if len(self.output_tfms) > 0:
|
||||
scheme_name = list(self.output_tfms.values())[0].scheme_name
|
||||
location = list(self.output_tfms.values())[0].args.location
|
||||
|
||||
for tfm in self.output_tfms.values():
|
||||
if tfm.scheme_name != scheme_name:
|
||||
raise ValueError("Must have same scheme name")
|
||||
if tfm.args.location != location:
|
||||
raise ValueError("Must have same location")
|
||||
|
||||
return self.input_tfms, self.output_tfms
|
||||
|
||||
|
||||
def get_linear_transform_schemes(
|
||||
layer: torch.nn.Module, layer_name: str,
|
||||
transform_config: Optional[TransformConfig],
|
||||
packed_modules_mapping: dict[str, list[str]]
|
||||
) -> tuple[dict[int, TransformTuple], dict[
|
||||
int, TransformTuple]]: # [input_transform, [output_transform, ...]]
|
||||
# there can only be one transform input scheme per (fused) module
|
||||
input_tfms = {}
|
||||
output_tfms = {}
|
||||
|
||||
partition_names = get_layer_partition_names(layer_name,
|
||||
packed_modules_mapping)
|
||||
|
||||
for scheme_name, scheme, args in get_schemes_args(transform_config):
|
||||
for part_index, part_name in enumerate(partition_names):
|
||||
if is_match(part_name, layer, args.targets,
|
||||
args.ignore) and args.is_online():
|
||||
if args.location == TransformLocation.INPUT:
|
||||
input_tfms[part_index] = TransformTuple(
|
||||
scheme_name, scheme, args)
|
||||
|
||||
elif args.location == TransformLocation.OUTPUT:
|
||||
output_tfms[part_index] = TransformTuple(
|
||||
scheme_name, scheme, args)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Cannot apply `{args.location}` "
|
||||
f"transform to `{layer_name}`")
|
||||
|
||||
return (input_tfms, output_tfms)
|
||||
|
||||
|
||||
def get_schemes_args(
|
||||
transform_config: Optional[TransformConfig]
|
||||
) -> Generator[tuple[str, TransformScheme, TransformArgs]]:
|
||||
if transform_config is None:
|
||||
return
|
||||
|
||||
for scheme_name, scheme in transform_config.config_groups.items():
|
||||
for args in scheme.apply:
|
||||
yield (scheme_name, scheme, args)
|
||||
|
||||
|
||||
def get_layer_partition_names(
|
||||
layer_name: str, packed_modules_mapping: dict[str,
|
||||
list[str]]) -> list[str]:
|
||||
"""
|
||||
Get all partition names associated with this layer.
|
||||
Names are returned in order of their partition indices.
|
||||
|
||||
```python
|
||||
mapping = {"gate_up_proj", "gate_proj", "up_proj"}
|
||||
|
||||
assert get_layer_partition_names(
|
||||
"mlp.gate_up_proj", mapping) == ["gate_proj", "up_proj"]
|
||||
assert get_layer_partition_names(
|
||||
"mlp.down_proj", mapping) == ["down_proj"]
|
||||
"""
|
||||
for fused_suffix, part_suffixes in packed_modules_mapping.items():
|
||||
if layer_name.endswith(fused_suffix):
|
||||
return [
|
||||
layer_name.removesuffix(fused_suffix) + part_suffix
|
||||
for part_suffix in part_suffixes
|
||||
]
|
||||
|
||||
return [layer_name]
|
||||
@@ -0,0 +1,153 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from collections.abc import Hashable
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from compressed_tensors.transform import (TransformArgs, TransformLocation,
|
||||
TransformScheme)
|
||||
from torch import Tensor
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501
|
||||
TransformTuple)
|
||||
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.parameter import SharedWeightParameter
|
||||
|
||||
|
||||
class HadamardTransform(torch.nn.Module):
|
||||
"""
|
||||
Class which handles weight loading, postprocessing, and application of
|
||||
transforms. Meant to be used with `CompressedTensorsLinearTransformMethod`
|
||||
and attention transforms method (not implemented yet)
|
||||
"""
|
||||
transforms: dict[int, TransformTuple] # info parsed from transforms config
|
||||
weight: SharedWeightParameter # container for shared tensors
|
||||
|
||||
scales: dict[int, float] # hadamard scale, usually sqrt(matrix.size(0))
|
||||
|
||||
def __init__(self, transforms: dict[int, TransformTuple],
|
||||
layer: torch.nn.Module, weight_loader: Callable,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int]):
|
||||
super().__init__()
|
||||
self.transforms = transforms
|
||||
self.scales = {}
|
||||
|
||||
if get_tensor_model_parallel_world_size() > 1:
|
||||
raise NotImplementedError("Online transforms with tensor "
|
||||
"parallelism is not supported")
|
||||
|
||||
# Similar to row/col parallel params, but tensors are separate
|
||||
# to allow for loading with shared memory
|
||||
self.weight = SharedWeightParameter(weight_loader=weight_loader)
|
||||
|
||||
# create shared partition data for each partition of the original weight
|
||||
input_size = input_size_per_partition
|
||||
for part_index, (_scheme_name, scheme,
|
||||
args) in self.transforms.items():
|
||||
output_size = output_partition_sizes[part_index]
|
||||
weight_size = self._get_weight_size(layer, scheme, args,
|
||||
input_size, output_size)
|
||||
|
||||
data_key = self._get_data_key(scheme, weight_size)
|
||||
self.weight.add_partition(
|
||||
part_index,
|
||||
data_key,
|
||||
size=(weight_size, weight_size),
|
||||
dtype=scheme.precision,
|
||||
)
|
||||
|
||||
# validate that shared tensors and schemes are correct
|
||||
self._validate_input_transforms()
|
||||
|
||||
def process_weights_after_loading(self):
|
||||
for part_id in self.weight.partitions:
|
||||
data = self.weight.partitions[part_id].data
|
||||
|
||||
# required by torch.compile
|
||||
self.weight.process_weights_after_loading()
|
||||
|
||||
# precompute scale as a runtime multiply, not division
|
||||
# do not fold into weight in order to utilize FWHT
|
||||
self.scales[part_id] = 1 / math.sqrt(data.size(0))
|
||||
|
||||
# FUTURE: avoid runtime transpose by processing weights
|
||||
# prior to apply
|
||||
|
||||
def forward(self, value: Tensor, part_id: int = 0) -> Tensor:
|
||||
if part_id not in self.weight.partitions:
|
||||
return value
|
||||
|
||||
# use hadacore if possible
|
||||
if self.transforms[part_id].scheme.type == "hadamard":
|
||||
if self.transforms[part_id].scheme.head_dim is not None:
|
||||
weight_size = self.transforms[part_id].scheme.head_dim
|
||||
value = value.unflatten(-1, (-1, weight_size))
|
||||
value = ops.hadacore_transform(value)
|
||||
value = value.flatten(-2, -1)
|
||||
|
||||
return value
|
||||
|
||||
# sylvester transforms are symmetric, inv => transpose => original
|
||||
return ops.hadacore_transform(value)
|
||||
|
||||
# fall back to dense
|
||||
else:
|
||||
weight = self.weight.partitions[part_id]
|
||||
weight = weight if self.transforms[
|
||||
part_id].args.inverse else weight.T # linear := x(W.T)
|
||||
scale = self.scales[part_id]
|
||||
|
||||
if self.transforms[part_id].scheme.head_dim is not None:
|
||||
value = value.unflatten(-1, (-1, weight.size(0)))
|
||||
value = dispatch_unquantized_gemm()(self, value.to(
|
||||
weight.dtype), weight, None).to(value.dtype) * scale
|
||||
value = value.flatten(-2, -1)
|
||||
|
||||
return value
|
||||
|
||||
return dispatch_unquantized_gemm()(self, value.to(
|
||||
weight.dtype), weight, None).to(value.dtype) * scale
|
||||
|
||||
def _get_data_key(self, scheme: TransformScheme,
|
||||
weight_size: int) -> Hashable:
|
||||
return (id(scheme), weight_size)
|
||||
|
||||
def _get_weight_size(self, layer: torch.nn.Module, scheme: TransformScheme,
|
||||
args: TransformArgs, input_size: int,
|
||||
output_size: int) -> int:
|
||||
if scheme.head_dim is not None:
|
||||
return scheme.head_dim
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if args.location == TransformLocation.INPUT:
|
||||
return input_size
|
||||
|
||||
elif args.location == TransformLocation.OUTPUT:
|
||||
return output_size
|
||||
|
||||
elif isinstance(layer, VocabParallelEmbedding):
|
||||
if args.location == TransformLocation.INPUT:
|
||||
return output_size
|
||||
|
||||
elif args.location == TransformLocation.OUTPUT:
|
||||
return input_size
|
||||
|
||||
raise ValueError()
|
||||
|
||||
def _validate_input_transforms(self):
|
||||
assert len(self.transforms) > 0
|
||||
location = list(self.transforms.values())[0].args.location
|
||||
|
||||
if location == TransformLocation.INPUT:
|
||||
first_data = self.weight.partitions[0].data
|
||||
for partition in self.weight.partitions.values():
|
||||
if partition.data.data_ptr() != first_data.data_ptr():
|
||||
raise ValueError("")
|
||||
@@ -0,0 +1,46 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensorsScheme, CompressedTensorsW4A4Fp4)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
|
||||
CompressedTensorsLinearTransformMethod, TransformTuple)
|
||||
|
||||
__all__ = ["is_qutlass_fp4_scheme", "QutlassNvFP4LinearMethod"]
|
||||
|
||||
|
||||
def is_qutlass_fp4_scheme(quant_scheme: Optional[CompressedTensorsScheme],
|
||||
input_tfms: dict[int, TransformTuple]) -> bool:
|
||||
return isinstance(
|
||||
quant_scheme,
|
||||
(CompressedTensorsW4A4Fp4, )) and len(input_tfms) == 1 and input_tfms[
|
||||
0].scheme.head_dim == quant_scheme.group_size
|
||||
|
||||
|
||||
class QutlassNvFP4LinearMethod(CompressedTensorsLinearTransformMethod):
|
||||
|
||||
def create_weights(self, layer, input_size_per_partition,
|
||||
output_partition_sizes, input_size, output_size,
|
||||
params_dtype, **extra_weight_attrs):
|
||||
# initializes fp4 qparams
|
||||
assert isinstance(layer.scheme, (CompressedTensorsW4A4Fp4, ))
|
||||
ret = super().create_weights(layer, input_size_per_partition,
|
||||
output_partition_sizes, input_size,
|
||||
output_size, params_dtype,
|
||||
**extra_weight_attrs)
|
||||
|
||||
assert self.input_transform is not None
|
||||
assert len(self.input_transform.weight) == 1
|
||||
assert self.input_transform.weight[0].size(
|
||||
0) == layer.scheme.group_size
|
||||
|
||||
return ret
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
@@ -0,0 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import NamedTuple
|
||||
|
||||
from compressed_tensors.transform import TransformArgs, TransformScheme
|
||||
|
||||
__all__ = ["TransformTuple"]
|
||||
|
||||
|
||||
class TransformTuple(NamedTuple):
|
||||
scheme_name: str
|
||||
scheme: TransformScheme
|
||||
args: TransformArgs
|
||||
@@ -0,0 +1,206 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
def is_weak_contiguous(x: torch.Tensor):
|
||||
strides = x.stride()
|
||||
sizes = x.shape
|
||||
is_not_transpose = strides[0] == 1 and (strides[1] >= max(1, sizes[0]))
|
||||
is_transpose = strides[1] == 1 and (strides[0] >= max(1, sizes[1]))
|
||||
return is_transpose or is_not_transpose
|
||||
|
||||
|
||||
@triton.jit
|
||||
def scaled_mm_kernel(a_ptr, b_ptr, scale_a_ptr, scale_b_ptr, c_ptr, bias_ptr,
|
||||
M, N, K, stride_am, stride_ak, stride_bk, stride_bn,
|
||||
stride_cm, stride_cn, ACCUMULATOR_DTYPE: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
BLOCK_SIZE_SCALE_A: tl.constexpr,
|
||||
BLOCK_SIZE_SCALE_B: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
|
||||
pid_m = pid // num_pid_n
|
||||
pid_n = pid % num_pid_n
|
||||
|
||||
accumulator_dtype = ACCUMULATOR_DTYPE
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),
|
||||
dtype=accumulator_dtype)
|
||||
|
||||
# NOTE: Some tensor inputs are so large, they will cause int32 overflow
|
||||
# so it is necessary to use tl.int64 for all the offsets, else SEGV will
|
||||
# eventually occur.
|
||||
|
||||
# Offsets and masks.
|
||||
offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
||||
masks_am = offsets_am < M
|
||||
|
||||
offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
|
||||
masks_bn = offsets_bn < N
|
||||
|
||||
offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
|
||||
offsets_a = (stride_am * offsets_am[:, None] +
|
||||
stride_ak * offsets_k[None, :])
|
||||
offsets_b = (stride_bk * offsets_k[:, None] +
|
||||
stride_bn * offsets_bn[None, :])
|
||||
|
||||
# NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create
|
||||
# appropriate offsets and masks for each case. Same goes for
|
||||
# BLOCK_SIZE_SCALE_B.
|
||||
offsets_scale_am = (tl.arange(0, BLOCK_SIZE_SCALE_A) +
|
||||
(BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M)
|
||||
masks_scale_am = offsets_scale_am < M
|
||||
|
||||
offsets_scale_bn = (tl.arange(0, BLOCK_SIZE_SCALE_B) +
|
||||
(BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N)
|
||||
masks_scale_bn = offsets_scale_bn < N
|
||||
|
||||
a_ptrs = a_ptr + offsets_a
|
||||
b_ptrs = b_ptr + offsets_b
|
||||
|
||||
scale_a_ptrs = scale_a_ptr + offsets_scale_am
|
||||
scale_b_ptrs = scale_b_ptr + offsets_scale_bn
|
||||
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
masks_k = offsets_k < K
|
||||
masks_a = masks_am[:, None] & masks_k[None, :]
|
||||
a = tl.load(a_ptrs, mask=masks_a)
|
||||
|
||||
masks_b = masks_k[:, None] & masks_bn[None, :]
|
||||
b = tl.load(b_ptrs, mask=masks_b)
|
||||
|
||||
# Accumulate results.
|
||||
accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
|
||||
|
||||
offsets_k += BLOCK_SIZE_K
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
# Apply scale at end.
|
||||
masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
|
||||
scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a)
|
||||
# Need to broadcast to the appropriate size, if scale_a is already
|
||||
# (BLOCK_SIZE_M, 1) then it will broadcast to its own shape. Same goes
|
||||
# for scale_b below.
|
||||
scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1))
|
||||
accumulator = scale_a * accumulator.to(tl.float32)
|
||||
|
||||
masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :]
|
||||
scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b)
|
||||
scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1))
|
||||
accumulator = scale_b.T * accumulator.to(tl.float32)
|
||||
|
||||
# Convert to output format.
|
||||
c = accumulator.to(c_ptr.type.element_ty)
|
||||
|
||||
# Add bias, it's already in output format, so add it after conversion.
|
||||
if bias_ptr:
|
||||
offsets_bias = offsets_bn
|
||||
bias_ptrs = bias_ptr + offsets_bias
|
||||
bias_mask = offsets_bias < N
|
||||
bias = tl.load(bias_ptrs, bias_mask)
|
||||
c += bias
|
||||
|
||||
# Save output
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
|
||||
offs_cm = offs_cm.to(tl.int64)
|
||||
offs_cn = offs_cn.to(tl.int64)
|
||||
c_ptrs = (c_ptr + stride_cm * offs_cm[:, None] +
|
||||
stride_cn * offs_cn[None, :])
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
|
||||
# input - [M, K]
|
||||
# weight - [K, N]
|
||||
def triton_scaled_mm(input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: type[torch.dtype],
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
block_size_m: int = 32,
|
||||
block_size_n: int = 32,
|
||||
block_size_k: int = 32,
|
||||
use_heuristic=True) -> torch.Tensor:
|
||||
M, K = input.shape
|
||||
N = weight.shape[1]
|
||||
|
||||
assert N > 0 and K > 0 and M > 0
|
||||
assert weight.shape[0] == K
|
||||
assert input.dtype == weight.dtype
|
||||
|
||||
scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a
|
||||
scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b
|
||||
|
||||
assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
|
||||
assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1
|
||||
or scale_a.shape[0] == M)
|
||||
assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1
|
||||
or scale_b.shape[0] == N)
|
||||
assert out_dtype.is_floating_point
|
||||
assert bias is None or bias.is_floating_point()
|
||||
assert is_weak_contiguous(input)
|
||||
assert is_weak_contiguous(weight)
|
||||
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
|
||||
N, META['BLOCK_SIZE_N']), )
|
||||
|
||||
result = torch.empty((M, N), dtype=out_dtype, device=input.device)
|
||||
|
||||
has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1
|
||||
|
||||
if use_heuristic:
|
||||
is_small_N = N < 8192
|
||||
next_power_of_2_M = max(32, triton.next_power_of_2(M))
|
||||
if next_power_of_2_M <= 32:
|
||||
tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256)
|
||||
elif next_power_of_2_M <= 64:
|
||||
tile_shape = (64, 64, 256)
|
||||
elif next_power_of_2_M <= 128:
|
||||
tile_shape = (64, 128, 128)
|
||||
else:
|
||||
tile_shape = (128, 128, 128)
|
||||
|
||||
block_size_m, block_size_n, block_size_k = tile_shape
|
||||
|
||||
block_size_sa = 1 if has_scalar(scale_a) else block_size_m
|
||||
block_size_sb = 1 if has_scalar(scale_b) else block_size_n
|
||||
|
||||
accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32
|
||||
|
||||
# A = input, B = weight, C = result
|
||||
# A = M x K, B = K x N, C = M x N
|
||||
scaled_mm_kernel[grid](input,
|
||||
weight,
|
||||
scale_a,
|
||||
scale_b,
|
||||
result,
|
||||
bias,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
input.stride(0),
|
||||
input.stride(1),
|
||||
weight.stride(0),
|
||||
weight.stride(1),
|
||||
result.stride(0),
|
||||
result.stride(1),
|
||||
accumulator_dtype,
|
||||
BLOCK_SIZE_M=block_size_m,
|
||||
BLOCK_SIZE_N=block_size_n,
|
||||
BLOCK_SIZE_K=block_size_k,
|
||||
BLOCK_SIZE_SCALE_A=block_size_sa,
|
||||
BLOCK_SIZE_SCALE_B=block_size_sb)
|
||||
|
||||
return result.to(out_dtype)
|
||||
@@ -0,0 +1,216 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable, Mapping
|
||||
from types import MappingProxyType
|
||||
from typing import Optional
|
||||
|
||||
import regex as re
|
||||
from compressed_tensors import CompressionFormat
|
||||
from torch.nn import Module
|
||||
|
||||
|
||||
def is_activation_quantization_format(format: str) -> bool:
|
||||
_ACTIVATION_QUANTIZATION_FORMATS = [
|
||||
CompressionFormat.naive_quantized.value,
|
||||
CompressionFormat.int_quantized.value,
|
||||
CompressionFormat.float_quantized.value,
|
||||
CompressionFormat.nvfp4_pack_quantized.value
|
||||
]
|
||||
return format in _ACTIVATION_QUANTIZATION_FORMATS
|
||||
|
||||
|
||||
def should_ignore_layer(
|
||||
layer_name: Optional[str],
|
||||
ignore: Iterable[str] = tuple(),
|
||||
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
|
||||
) -> bool:
|
||||
if layer_name is None:
|
||||
return False
|
||||
|
||||
# layer_name = model.layers.0.self_attn.qkv_proj
|
||||
# proj_name = qkv_proj
|
||||
proj_name = layer_name.split(".")[-1]
|
||||
|
||||
# Fused layers like gate_up_proj or qkv_proj will not be fused
|
||||
# in the safetensors checkpoint. So, we convert the name
|
||||
# from the fused version to unfused + check to make sure that
|
||||
# each shard of the fused layer has the same scheme.
|
||||
if proj_name in fused_mapping and layer_name not in ignore:
|
||||
shard_proj_names = fused_mapping[proj_name]
|
||||
|
||||
# Convert fused_name --> [shard_names]
|
||||
shard_names = [
|
||||
layer_name.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in shard_proj_names
|
||||
]
|
||||
|
||||
# Layer should be ignored if shards are ignored.
|
||||
should_ignore_layer = None
|
||||
for shard_name in shard_names:
|
||||
should_ignore_shard = check_equal_or_regex_match(
|
||||
layer_name=shard_name, targets=ignore)
|
||||
|
||||
# If shard_idx=0, set layer ignore to match shard.
|
||||
if should_ignore_layer is None:
|
||||
should_ignore_layer = should_ignore_shard
|
||||
|
||||
# If shard_idx=1+ confirm scheme matches prior shards.
|
||||
elif should_ignore_shard != should_ignore_layer:
|
||||
raise ValueError(f"Found a different quantization schemes for "
|
||||
f"{shard_proj_names} in {layer_name}. vLLM "
|
||||
"requires all to use the same scheme.")
|
||||
|
||||
# Unfused layers like down_proj and o_proj will match
|
||||
# the safetensors checkpoint already.
|
||||
else:
|
||||
should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name,
|
||||
targets=ignore)
|
||||
|
||||
assert should_ignore_layer is not None
|
||||
return should_ignore_layer
|
||||
|
||||
|
||||
def check_equal_or_regex_match(layer_name: str,
|
||||
targets: Iterable[str]) -> bool:
|
||||
"""
|
||||
Checks whether a layer_name is exactly equal or a regex match for
|
||||
if target starts with 're:' to any target in list.
|
||||
"""
|
||||
for target in targets:
|
||||
if _is_equal_or_regex_match(layer_name, target):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def find_matched_target(
|
||||
layer_name: Optional[str],
|
||||
module: Module,
|
||||
targets: Iterable[str],
|
||||
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
|
||||
) -> str:
|
||||
"""
|
||||
Helper function to look up which "target" in the compressed-tensors
|
||||
config that a layer corresponds to.
|
||||
|
||||
Recall that a compressed-tensors configs has a concept of
|
||||
config_groups, where each layer can be quantized with a different
|
||||
scheme.
|
||||
|
||||
targets in each config_group will be a list of either layer names
|
||||
(or regexes corresponding to layer names) or names of torch Modules.
|
||||
|
||||
First, we try to match the layer_name with a target
|
||||
Second, we try to match the module's name with a target
|
||||
Third, we try to map the layer_name to a list of fused module names.
|
||||
*All* component module names must match in order for a match to be
|
||||
successful. A successful match returns the first component target
|
||||
|
||||
:param layer_name: layer name
|
||||
:param module: torch.nn.Module
|
||||
:param targets: list of targets to match the layer against
|
||||
:param fused_mapping: map from fused layer names to its components
|
||||
:param fused_strategy: either "all" or "any". If using "all", fused
|
||||
layers match if "all" of its components match
|
||||
"""
|
||||
|
||||
if layer_name is None:
|
||||
layer_name = ""
|
||||
|
||||
matched_target = (
|
||||
_find_first_match(layer_name, targets)
|
||||
or _find_first_match(module.__class__.__name__, targets, True)
|
||||
or _match_fused_layer(layer_name, targets, fused_mapping))
|
||||
|
||||
if matched_target is None:
|
||||
raise ValueError(
|
||||
f"Unable to find matching target for {layer_name} in the "
|
||||
"compressed-tensors config.")
|
||||
|
||||
return matched_target
|
||||
|
||||
|
||||
def _find_first_match(value: str,
|
||||
targets: Iterable[str],
|
||||
check_contains: bool = False) -> Optional[str]:
|
||||
"""
|
||||
Returns first element of target that matches value either
|
||||
exactly or as a regex after 're:'. If check_contains is set to True,
|
||||
additionally checks if the target string is contained within the value.
|
||||
|
||||
:param value: string to compare the list of targets against
|
||||
:param targets: list of targets to match the layer against
|
||||
:param check_contains: whether or not to do a substring match
|
||||
"""
|
||||
|
||||
for target in targets:
|
||||
if _is_equal_or_regex_match(value,
|
||||
target,
|
||||
check_contains=check_contains):
|
||||
return target
|
||||
return None
|
||||
|
||||
|
||||
def _is_equal_or_regex_match(value: str,
|
||||
target: str,
|
||||
check_contains: bool = False) -> bool:
|
||||
"""
|
||||
Checks whether a value is exactly equal or a regex match for target
|
||||
if target starts with 're:'. If check_contains is set to True,
|
||||
additionally checks if the target string is contained within the value.
|
||||
"""
|
||||
|
||||
if target.startswith("re:"):
|
||||
pattern = target[3:]
|
||||
if re.match(pattern, value):
|
||||
return True
|
||||
elif check_contains:
|
||||
if target.lower() in value.lower():
|
||||
return True
|
||||
elif target == value:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _match_fused_layer(
|
||||
layer_name: str, target_layers: Iterable[str],
|
||||
fused_mapping: Mapping[str, list[str]]) -> Optional[str]:
|
||||
"""
|
||||
Match a fused layer name to its corresponding individual layer in
|
||||
target_layers. Returns first value in fused_mapping which matches targets
|
||||
|
||||
Implements an "all" matching strategy where a fused layer matches iff
|
||||
"all" of its components match
|
||||
|
||||
:param layer_name: layer name
|
||||
:param target_layers: list of targets to match the layer against
|
||||
:param fused_mapping: map from fused layer names to its components
|
||||
|
||||
Examples:
|
||||
layer_name = "model.layers.0.self_attn.qkv_proj"
|
||||
target_layers = ["model.layers.0.self_attn.q_proj",
|
||||
"model.layers.0.self_attn.k_proj",
|
||||
"model.layers.0.self_attn.v_proj"]
|
||||
"""
|
||||
# find layer_name in mapping
|
||||
fused = next((key for key in fused_mapping if layer_name.endswith(key)),
|
||||
None)
|
||||
if fused is None:
|
||||
return None
|
||||
|
||||
# expand path of unfused components
|
||||
unfused_paths = [
|
||||
layer_name.replace(fused, unfused) for unfused in fused_mapping[fused]
|
||||
]
|
||||
|
||||
# for each unfused component, find a match in targets
|
||||
unfused_matches: list[Optional[str]] = []
|
||||
for unfused in unfused_paths:
|
||||
for target in target_layers:
|
||||
if _is_equal_or_regex_match(unfused, target):
|
||||
unfused_matches.append(target)
|
||||
break
|
||||
else:
|
||||
unfused_matches.append(None)
|
||||
|
||||
return unfused_matches[0] if all(unfused_matches) else None
|
||||
196
vllm/model_executor/layers/quantization/deepspeedfp.py
Normal file
196
vllm/model_executor/layers/quantization/deepspeedfp.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from packaging import version
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
|
||||
class DeepSpeedFPConfig(QuantizationConfig):
|
||||
"""Config for DeepSpeed FP quantizer. It supports fp6 and fp8.
|
||||
|
||||
Args:
|
||||
weight_bits: the target quantization bits, 6 or 8.
|
||||
group_size: group size for quantizaiton, default to 128.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int = 8,
|
||||
group_size: int = 512,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.valid_types = [torch.bfloat16, torch.float16]
|
||||
|
||||
if self.weight_bits not in (6, 8):
|
||||
raise ValueError(
|
||||
"Currently, only 6-bit or 8-bit weight quantization are "
|
||||
f"supported for DeepSpeed FP quantizaiton, but got "
|
||||
f"{self.weight_bits} bits.")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"DeepSpeedFPConfig(weight_bits={self.weight_bits}), "
|
||||
f"group_size={self.group_size}")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "deepspeedfp"
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "DeepSpeedFPConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
return cls(weight_bits=weight_bits, group_size=group_size)
|
||||
|
||||
def get_linear_method(self) -> "DeepSpeedFPLinearMethod":
|
||||
return DeepSpeedFPLinearMethod(self)
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
def get_min_capability(cls) -> int:
|
||||
return 60
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
return [
|
||||
"quant_config.json",
|
||||
"quantize_config.json",
|
||||
]
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["DeepSpeedFPLinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return DeepSpeedFPLinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class DeepSpeedFPLinearMethod(LinearMethodBase):
|
||||
"""Linear method for DeepSpeedFP quantizer.
|
||||
|
||||
Args:
|
||||
quant_config: the DeepSpeedFP quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: DeepSpeedFPConfig):
|
||||
self.quant_config = quant_config
|
||||
self.weight = None
|
||||
|
||||
def create_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader=None,
|
||||
**extra_weight_attrs):
|
||||
del output_size
|
||||
del input_size
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight = DeepSpeedFPParameter(
|
||||
torch.Size((output_size_per_partition, input_size_per_partition)),
|
||||
params_dtype=params_dtype,
|
||||
quant_config=self.quant_config,
|
||||
)
|
||||
set_weight_attrs(weight, {
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
})
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
def quant_weight_loader(param, loaded_weight, *args, **kwargs):
|
||||
# Calls the original weight loader (if any), quantizes the result,
|
||||
# and then loads the quantized parameter.
|
||||
if weight_loader is not None:
|
||||
orig_param_data = param.data
|
||||
param.data = param.ds_dequantize()
|
||||
weight_loader(param, loaded_weight, *args, **kwargs)
|
||||
param.data, loaded_weight = orig_param_data, param.data
|
||||
param.ds_quantize_(loaded_weight.cuda())
|
||||
|
||||
extra_weight_attrs["weight_loader"] = quant_weight_loader
|
||||
set_weight_attrs(weight, extra_weight_attrs)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
weight = layer.weight
|
||||
y = weight.ds_dequantize()
|
||||
return F.linear(x, y, bias)
|
||||
|
||||
|
||||
class DeepSpeedFPParameter(nn.Parameter):
|
||||
"""
|
||||
DeepSpeedFP quantized parameter class that implements fp8/fp6
|
||||
quantization deepspeed. Weights are stored in quantized form on
|
||||
GPUs, and can be dequantized on-the-fly when needed by the model.
|
||||
"""
|
||||
|
||||
def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype,
|
||||
quant_config: DeepSpeedFPConfig):
|
||||
try:
|
||||
import deepspeed
|
||||
if version.parse(deepspeed.__version__) < version.parse("0.14.2"):
|
||||
raise ImportError("deepspeed version is wrong. Please "
|
||||
"install deepspeed>=0.14.2.")
|
||||
from deepspeed.ops.fp_quantizer import FP_Quantize
|
||||
except ImportError as err:
|
||||
raise ImportError("Please install deepspeed>=0.14.2 via "
|
||||
"`pip install deepspeed>=0.14.2` to use "
|
||||
"deepspeedfp quantizer.") from err
|
||||
data = torch.empty((
|
||||
orig_shape.numel() // quant_config.group_size,
|
||||
quant_config.group_size * quant_config.weight_bits // 8 + 4,
|
||||
),
|
||||
dtype=torch.int8)
|
||||
self = torch.Tensor._make_subclass(cls, data, data.requires_grad)
|
||||
self.orig_shape = orig_shape
|
||||
self.quant_config = quant_config
|
||||
self.fp_quantizer = FP_Quantize(group_size=quant_config.group_size)
|
||||
self.fp_quantizer.orig_shape = orig_shape
|
||||
self.fp_quantizer.orig_dtype = params_dtype
|
||||
return self
|
||||
|
||||
def ds_quantize_(self, tensor: torch.Tensor):
|
||||
assert tensor.device.type == "cuda" and tensor.dtype != torch.int8
|
||||
return self.data.copy_(
|
||||
self.fp_quantizer.quantize(
|
||||
tensor.data,
|
||||
q_bits=self.quant_config.weight_bits,
|
||||
))
|
||||
|
||||
def ds_dequantize(self, fp_out=None) -> torch.Tensor:
|
||||
"""
|
||||
Return a tensor containing the dequantized weights of this parameter.
|
||||
"""
|
||||
assert self.data.device.type == "cuda" and self.data.dtype == torch.int8
|
||||
return self.fp_quantizer.dequantize(
|
||||
self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits)
|
||||
|
||||
def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor:
|
||||
"""
|
||||
Return a tensor where only the weights at `indices` are dequantized
|
||||
(to save HBM -> SRAM bandwidth).
|
||||
"""
|
||||
assert self.data.device.type == "cuda" and self.data.dtype == torch.int8
|
||||
return self.fp_quantizer.selective_dequantize(
|
||||
self.data,
|
||||
indices,
|
||||
fp_out=fp_out,
|
||||
q_bits=self.quant_config.weight_bits)
|
||||
223
vllm/model_executor/layers/quantization/experts_int8.py
Normal file
223
vllm/model_executor/layers/quantization/experts_int8.py
Normal file
@@ -0,0 +1,223 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, int8_w8a16_moe_quant_config)
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
|
||||
class ExpertsInt8Config(QuantizationConfig):
|
||||
"""Config class for Int8 experts quantization."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "experts_int8"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "ExpertsInt8Config":
|
||||
return cls()
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return UnquantizedLinearMethod()
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return ExpertsInt8MoEMethod(self, layer.moe_config)
|
||||
return None
|
||||
|
||||
|
||||
class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: ExpertsInt8Config,
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
|
||||
int8_dtype = torch.int8
|
||||
|
||||
assert 'weight_loader' in extra_weight_attrs
|
||||
weight_loader = extra_weight_attrs['weight_loader']
|
||||
wrapped_weight_loader = ExpertsInt8MoEMethod.quantizing_weight_loader(
|
||||
layer, weight_loader)
|
||||
extra_weight_attrs['weight_loader'] = wrapped_weight_loader
|
||||
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=int8_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=int8_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
w13_scale = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_scale", w13_scale)
|
||||
|
||||
w2_scale = torch.nn.Parameter(torch.zeros(num_experts,
|
||||
hidden_size,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_scale", w2_scale)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
return int8_w8a16_moe_quant_config(w1_scale=layer.w13_scale,
|
||||
w2_scale=layer.w2_scale,
|
||||
w1_zp=None,
|
||||
w2_zp=None)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `ExpertsInt8MoEMethod` yet.")
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def quantizing_weight_loader(layer, weight_loader):
|
||||
|
||||
def quantize_and_call_weight_loader(param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
weight_name: str, shard_id: int,
|
||||
expert_id: int):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = layer.intermediate_size_per_partition
|
||||
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
|
||||
device = get_tp_group().device
|
||||
loaded_weight = loaded_weight.to(device)
|
||||
# w1, gate_proj case: Load into first shard of w13.
|
||||
if shard_id == "w1":
|
||||
scales = quantize_in_place_and_get_scales(
|
||||
loaded_weight[shard, :])
|
||||
layer.w13_scale.data[expert_id, 0:shard_size].copy_(scales[:,
|
||||
0])
|
||||
# w3, up_proj case: Load into second shard of w13.
|
||||
elif shard_id == "w3":
|
||||
scales = quantize_in_place_and_get_scales(
|
||||
loaded_weight[shard, :])
|
||||
layer.w13_scale.data[expert_id, shard_size:2 *
|
||||
shard_size].copy_(scales[:, 0])
|
||||
# w2, down_proj case: Load into only shard of w2.
|
||||
elif shard_id == "w2":
|
||||
scales = quantize_in_place_and_get_scales(loaded_weight[:,
|
||||
shard])
|
||||
layer.w2_scale.data[expert_id, :].copy_(scales[:, 0])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Shard id must be in [0,1,2] but got {shard_id}")
|
||||
weight_loader(param, loaded_weight, weight_name, shard_id,
|
||||
expert_id)
|
||||
|
||||
return quantize_and_call_weight_loader
|
||||
|
||||
|
||||
def quantize_in_place_and_get_scales(weight: torch.Tensor) -> torch.Tensor:
|
||||
vmax = torch.iinfo(torch.int8).max
|
||||
scales = (torch.max(torch.abs(weight), dim=1, keepdim=True)[0] / vmax)
|
||||
|
||||
weight.div_(scales)
|
||||
weight.round_()
|
||||
weight.clamp_(-vmax, vmax)
|
||||
|
||||
return scales
|
||||
172
vllm/model_executor/layers/quantization/fbgemm_fp8.py
Normal file
172
vllm/model_executor/layers/quantization/fbgemm_fp8.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape, is_layer_skipped)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
ModelWeightParameter)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FBGEMMFp8Config(QuantizationConfig):
|
||||
"""Config class for FBGEMM Fp8."""
|
||||
|
||||
def __init__(self, ignore_list: list[str], input_scale_ub: float):
|
||||
super().__init__()
|
||||
self.ignore_list = ignore_list if ignore_list else []
|
||||
self.input_scale_ub = input_scale_ub
|
||||
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
self.use_marlin = not current_platform.has_device_capability(89)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "fbgemm_fp8"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.float16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "FBGEMMFp8Config":
|
||||
ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
|
||||
input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
|
||||
return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(prefix=prefix,
|
||||
ignored_layers=self.ignore_list,
|
||||
fused_mapping=self.packed_modules_mapping):
|
||||
return UnquantizedLinearMethod()
|
||||
return FBGEMMFp8LinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||
|
||||
def __init__(self, quant_config: FBGEMMFp8Config):
|
||||
self.quant_config = quant_config
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN)
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
maybe_create_device_identity()
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
del input_size, output_size
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.orig_dtype = params_dtype
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=torch.float8_e4m3fn),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
weight_scale = ChannelQuantScaleParameter(data=torch.empty(
|
||||
(sum(output_partition_sizes), 1), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
weight_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE UPPER BOUND
|
||||
input_scale_ub = torch.nn.Parameter(torch.tensor(
|
||||
(self.quant_config.input_scale_ub), dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.input_scale_ub = input_scale_ub
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
# required by torch.compile
|
||||
layer.weight_scale = Parameter(layer.weight_scale.data,
|
||||
requires_grad=False)
|
||||
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
||||
|
||||
weight = layer.weight
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
weight, weight_scale, input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=None)
|
||||
if input_scale is not None:
|
||||
layer.input_scale = Parameter(input_scale, requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
if self.quant_config.use_marlin:
|
||||
prepare_fp8_layer_for_marlin(layer)
|
||||
# Activations not quantized for marlin.
|
||||
del layer.input_scale_ub
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
if self.quant_config.use_marlin:
|
||||
return apply_fp8_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias)
|
||||
|
||||
return self.fp8_linear.apply(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
input_scale=None,
|
||||
input_scale_ub=layer.input_scale_ub,
|
||||
bias=bias)
|
||||
1098
vllm/model_executor/layers/quantization/fp8.py
Normal file
1098
vllm/model_executor/layers/quantization/fp8.py
Normal file
File diff suppressed because it is too large
Load Diff
599
vllm/model_executor/layers/quantization/gguf.py
Normal file
599
vllm/model_executor/layers/quantization/gguf.py
Normal file
@@ -0,0 +1,599 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import gguf
|
||||
import torch
|
||||
from gguf import GGMLQuantizationType as WeightType
|
||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
||||
FusedMoEQuantConfig)
|
||||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GGUFConfig(QuantizationConfig):
|
||||
"""Config class for GGUF."""
|
||||
|
||||
def __init__(self,
|
||||
unquantized_modules: Optional[list[str]] = None) -> None:
|
||||
super().__init__()
|
||||
self.unquantized_modules = unquantized_modules or []
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return ("GGUFConfig()")
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "gguf"
|
||||
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16, torch.float32]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 60
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return [] # no extra configs.
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
|
||||
return cls()
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped_gguf(prefix, self.unquantized_modules):
|
||||
return UnquantizedLinearMethod()
|
||||
return GGUFLinearMethod(self)
|
||||
elif isinstance(layer, VocabParallelEmbedding):
|
||||
return GGUFEmbeddingMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return GGUFMoEMethod(self, layer.moe_config)
|
||||
return None
|
||||
|
||||
|
||||
def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]):
|
||||
return any(module_name in prefix for module_name in unquantized_modules)
|
||||
|
||||
|
||||
UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
|
||||
STANDARD_QUANT_TYPES = {
|
||||
WeightType.Q4_0,
|
||||
WeightType.Q4_1,
|
||||
WeightType.Q5_0,
|
||||
WeightType.Q5_1,
|
||||
WeightType.Q8_0,
|
||||
WeightType.Q8_1,
|
||||
}
|
||||
KQUANT_TYPES = {
|
||||
WeightType.Q2_K,
|
||||
WeightType.Q3_K,
|
||||
WeightType.Q4_K,
|
||||
WeightType.Q5_K,
|
||||
WeightType.Q6_K,
|
||||
}
|
||||
IMATRIX_QUANT_TYPES = {
|
||||
WeightType.IQ1_M,
|
||||
WeightType.IQ1_S,
|
||||
WeightType.IQ2_XXS,
|
||||
WeightType.IQ2_XS,
|
||||
WeightType.IQ2_S,
|
||||
WeightType.IQ3_XXS,
|
||||
WeightType.IQ3_S,
|
||||
WeightType.IQ4_XS,
|
||||
WeightType.IQ4_NL,
|
||||
}
|
||||
# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
|
||||
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
|
||||
# MMQ kernel for I-Matrix quantization.
|
||||
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
|
||||
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
|
||||
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
|
||||
|
||||
|
||||
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
|
||||
qweight_type: int) -> torch.Tensor:
|
||||
if qweight_type in IMATRIX_QUANT_TYPES:
|
||||
mmvq_safe = 8 if qweight.shape[0] > 5120 else 16
|
||||
else:
|
||||
mmvq_safe = 2 if qweight.shape[0] > 5120 else 6
|
||||
# HACK: when doing chunked prefill we don't generate output tokens
|
||||
# so input to logits generator is empty which causes invalid parameter
|
||||
if x.shape[0] == 0:
|
||||
return torch.empty(x.shape[0],
|
||||
qweight.shape[0],
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
# there is no need to call any kernel for fp16/bf16
|
||||
if qweight_type in UNQUANTIZED_TYPES:
|
||||
return x @ qweight.T
|
||||
# enable MMVQ in contiguous batching with batch_size=1
|
||||
if x.shape[0] <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES:
|
||||
y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
|
||||
# Use MMQ Kernel if it's available (standard + k-quants)
|
||||
elif qweight_type in MMQ_QUANT_TYPES:
|
||||
y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
|
||||
# If there is no available MMQ kernel, fallback to dequantize
|
||||
elif qweight_type in DEQUANT_TYPES:
|
||||
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
|
||||
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
|
||||
weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
|
||||
y = x @ weight.T
|
||||
else:
|
||||
# Raise an error if the quantization type is not supported.
|
||||
# Might be useful if llama.cpp adds a new quantization type.
|
||||
# Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
|
||||
qweight_type = WeightType(qweight_type)
|
||||
raise NotImplementedError(
|
||||
f"Unsupported GGUF quantization type: {qweight_type}")
|
||||
return y
|
||||
|
||||
|
||||
def _fused_mul_mat_gguf_fake(
|
||||
x: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
qweight_type: int,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(x.shape[0],
|
||||
qweight.shape[0],
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="_fused_mul_mat_gguf",
|
||||
op_func=_fused_mul_mat_gguf,
|
||||
fake_impl=_fused_mul_mat_gguf_fake,
|
||||
)
|
||||
fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
|
||||
def _fused_moe_gguf(
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
qweight_type: int,
|
||||
qweight_type2: int,
|
||||
activation: str,
|
||||
) -> torch.Tensor:
|
||||
|
||||
def act(x: torch.Tensor):
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
if activation == "silu":
|
||||
torch.ops._C.silu_and_mul(out, x)
|
||||
elif activation == "gelu":
|
||||
torch.ops._C.gelu_and_mul(out, x)
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation: {activation}")
|
||||
return out
|
||||
|
||||
# lazy import to avoid triggering triton import in CPU backend
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
moe_align_block_size)
|
||||
|
||||
out_hidden_states = torch.empty_like(x)
|
||||
# unless we decent expert reuse we are better off running moe_vec kernel
|
||||
if (qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES
|
||||
and x.shape[0] > 64):
|
||||
num_tokens, _ = x.shape
|
||||
E, N, _ = w1.shape
|
||||
top_k = topk_ids.shape[1]
|
||||
BLOCK_SIZE = ops.ggml_moe_get_block_size(qweight_type)
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = \
|
||||
moe_align_block_size(topk_ids, BLOCK_SIZE, E)
|
||||
out = ops.ggml_moe_a8(x, w1, sorted_token_ids, expert_ids,
|
||||
num_tokens_post_padded, qweight_type, N, top_k,
|
||||
num_tokens)
|
||||
out = act(out)
|
||||
out = ops.ggml_moe_a8(out, w2, sorted_token_ids, expert_ids,
|
||||
num_tokens_post_padded, qweight_type2,
|
||||
w2.shape[1], 1, num_tokens * top_k)
|
||||
out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
|
||||
topk_weights.view(num_tokens, top_k, 1))
|
||||
ops.moe_sum(out, out_hidden_states)
|
||||
elif qweight_type2 in MMVQ_QUANT_TYPES and qweight_type in MMVQ_QUANT_TYPES:
|
||||
num_tokens, _ = x.shape
|
||||
E, N, _ = w1.shape
|
||||
top_k = topk_ids.shape[1]
|
||||
|
||||
out = ops.ggml_moe_a8_vec(x, w1, topk_ids, top_k, qweight_type, N,
|
||||
num_tokens)
|
||||
out = act(out)
|
||||
|
||||
out = ops.ggml_moe_a8_vec(out, w2, topk_ids, 1, qweight_type2,
|
||||
w2.shape[1], num_tokens * top_k)
|
||||
out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
|
||||
topk_weights.view(num_tokens, top_k, 1))
|
||||
ops.moe_sum(out, out_hidden_states)
|
||||
else:
|
||||
logger.warning_once("There is no support for fast MoE kernel "
|
||||
"for current quantization method. "
|
||||
"Falling back to slow implementation. ")
|
||||
for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
|
||||
inp = x[tok].reshape((1, ) + x.shape[1:])
|
||||
current_hidden_state = None
|
||||
for ww, ii in zip(w, idx):
|
||||
expert_up = w1[ii]
|
||||
|
||||
out = fused_mul_mat_gguf(inp, expert_up, qweight_type)
|
||||
out = act(out)
|
||||
|
||||
expert_down = w2[ii]
|
||||
current_state = fused_mul_mat_gguf(out, expert_down,
|
||||
qweight_type2).mul_(ww)
|
||||
if current_hidden_state is None:
|
||||
current_hidden_state = current_state
|
||||
else:
|
||||
current_hidden_state.add_(current_state)
|
||||
out_hidden_states[tok] = current_hidden_state
|
||||
return out_hidden_states
|
||||
|
||||
|
||||
def _fused_moe_gguf_fake(
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
qweight_type: int,
|
||||
qweight_type2: int,
|
||||
activation: str,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="_fused_moe_gguf",
|
||||
op_func=_fused_moe_gguf,
|
||||
fake_impl=_fused_moe_gguf_fake,
|
||||
)
|
||||
fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
|
||||
def _apply_gguf_embedding(
|
||||
x: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
qweight_type: int,
|
||||
hidden_size: int,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> torch.Tensor:
|
||||
if qweight_type in UNQUANTIZED_TYPES:
|
||||
return torch.embedding(qweight, x)
|
||||
elif qweight_type in DEQUANT_TYPES:
|
||||
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
|
||||
x_flat = x.flatten()
|
||||
assert (hidden_size == qweight.shape[1] // type_size * block_size)
|
||||
quant = torch.index_select(qweight, dim=0, index=x_flat)
|
||||
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
|
||||
x_flat.shape[0], dtype)
|
||||
return dequant.view(*x.shape, hidden_size)
|
||||
else:
|
||||
qweight_type = WeightType(qweight_type)
|
||||
raise NotImplementedError(
|
||||
f"Unsupported GGUF quantization type: {qweight_type}")
|
||||
|
||||
|
||||
def _apply_gguf_embedding_fake(
|
||||
x: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
qweight_type: int,
|
||||
hidden_size: int,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device)
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="_apply_gguf_embedding",
|
||||
op_func=_apply_gguf_embedding,
|
||||
fake_impl=_apply_gguf_embedding_fake,
|
||||
)
|
||||
apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
|
||||
class GGUFLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GGUF.
|
||||
|
||||
Args:
|
||||
quant_config: The GGUF quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: GGUFConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
self.params_dtype = params_dtype
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
tensor_shape = (output_size_per_partition, input_size_per_partition)
|
||||
qweight = GGUFUninitializedParameter(requires_grad=False)
|
||||
set_weight_attrs(
|
||||
qweight, {
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
"tensor_shape": tensor_shape,
|
||||
"is_gguf_weight": True,
|
||||
"data_container": [],
|
||||
"shard_id": [],
|
||||
"shard_id_map": {},
|
||||
})
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
layer.register_parameter("qweight", qweight)
|
||||
|
||||
qweight_type = Parameter(torch.empty(len(output_partition_sizes),
|
||||
dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
set_weight_attrs(
|
||||
qweight_type, {
|
||||
"is_gguf_weight_type": True,
|
||||
"weight_type": 0,
|
||||
"shard_weight_type": {},
|
||||
"ignore_warning": True
|
||||
})
|
||||
set_weight_attrs(qweight_type, extra_weight_attrs)
|
||||
layer.register_parameter("qweight_type", qweight_type)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
qweight_type = layer.qweight_type.weight_type
|
||||
if not (qweight_type in UNQUANTIZED_TYPES
|
||||
or qweight_type in DEQUANT_TYPES):
|
||||
qweight_type = WeightType(qweight_type)
|
||||
raise ValueError(
|
||||
f"Unsupported GGUF quantization type {qweight_type} in "
|
||||
f"layer {layer}.")
|
||||
# For MergedColumnParallelLinear and QKVParallelLinear, we need to
|
||||
# materialize the padded weight parameter for CUDA Graph compatibility.
|
||||
self._create_padded_weight_param(layer)
|
||||
|
||||
def _create_padded_weight_param(self, layer: torch.nn.Module):
|
||||
"""Create padded weight parameter for GGUF MergedLinear layer."""
|
||||
qweight = layer.qweight
|
||||
shard_id_map = qweight.shard_id_map
|
||||
shard_id = qweight.shard_id
|
||||
if len(data_container := qweight.data_container) > 1:
|
||||
dtype = {data.dtype for data in data_container}
|
||||
assert len(dtype) == 1, ValueError(
|
||||
f"Data container has mixed dtypes: {dtype}")
|
||||
dtype = next(iter(dtype))
|
||||
# concat dim0 and pad dim1
|
||||
padded_side = max(x.size(1) for x in data_container)
|
||||
concat_side = sum(x.size(0) for x in data_container)
|
||||
# Pad the quantized weights to dense tensor, and create a map
|
||||
# with the location of each shard in the padded tensor.
|
||||
padded_data = torch.zeros((concat_side, padded_side),
|
||||
dtype=dtype,
|
||||
device=qweight.device)
|
||||
# (dim0_start, dim0_end, dim1_size)
|
||||
shard_offset_map = dict[str, tuple[int, int, int]]()
|
||||
for idx in shard_id:
|
||||
id_in_container = shard_id_map[idx]
|
||||
start = sum(
|
||||
x.size(0) for x in data_container[:id_in_container])
|
||||
end = start + data_container[id_in_container].size(0)
|
||||
size = data_container[id_in_container].size(1)
|
||||
padded_data[start:end, :size] = data_container[id_in_container]
|
||||
shard_offset_map[idx] = (start, end, size)
|
||||
qweight.data_container.clear()
|
||||
padded_param = Parameter(padded_data, requires_grad=False)
|
||||
set_weight_attrs(padded_param, vars(qweight))
|
||||
set_weight_attrs(padded_param,
|
||||
{"shard_offset_map": shard_offset_map})
|
||||
layer.register_parameter("qweight", padded_param)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
shard_id = layer.qweight.shard_id
|
||||
|
||||
if shard_id:
|
||||
# dequantize shard weights respectively
|
||||
shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
|
||||
qweight = layer.qweight
|
||||
result = []
|
||||
for idx in shard_id:
|
||||
start, end, offset = layer.qweight.shard_offset_map[idx]
|
||||
qweight_type = layer.qweight_type.shard_weight_type[idx]
|
||||
result.append(
|
||||
fused_mul_mat_gguf(
|
||||
x, qweight[start:end, :offset].contiguous(),
|
||||
qweight_type))
|
||||
out = torch.cat(result, axis=1)
|
||||
else:
|
||||
qweight = layer.qweight
|
||||
qweight_type = layer.qweight_type.weight_type
|
||||
out = fused_mul_mat_gguf(x, qweight, qweight_type)
|
||||
if bias is not None:
|
||||
out.add_(bias)
|
||||
return out
|
||||
|
||||
|
||||
class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
"""MoE method for GGUF.
|
||||
|
||||
Args:
|
||||
quant_config: The GGUF quantization config.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: GGUFConfig,
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
|
||||
tensor_shape = (num_experts, 2 * intermediate_size_per_partition,
|
||||
hidden_size)
|
||||
#gate up proj
|
||||
w13_qweight = GGUFUninitializedParameter(requires_grad=False)
|
||||
set_weight_attrs(
|
||||
w13_qweight, {
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
"tensor_shape": tensor_shape,
|
||||
"is_gguf_weight": True,
|
||||
"data_container": [],
|
||||
})
|
||||
set_weight_attrs(w13_qweight, extra_weight_attrs)
|
||||
layer.register_parameter("w13_qweight", w13_qweight)
|
||||
|
||||
w13_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
set_weight_attrs(w13_qweight_type, {
|
||||
"is_gguf_weight_type": True,
|
||||
"weight_type": 0,
|
||||
"ignore_warning": True
|
||||
})
|
||||
set_weight_attrs(w13_qweight_type, extra_weight_attrs)
|
||||
layer.register_parameter("w13_qweight_type", w13_qweight_type)
|
||||
|
||||
tensor_shape = (num_experts, intermediate_size_per_partition,
|
||||
hidden_size)
|
||||
#gate down proj
|
||||
w2_qweight = GGUFUninitializedParameter(requires_grad=False)
|
||||
set_weight_attrs(
|
||||
w2_qweight, {
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
"tensor_shape": tensor_shape,
|
||||
"is_gguf_weight": True,
|
||||
"data_container": [],
|
||||
})
|
||||
set_weight_attrs(w2_qweight, extra_weight_attrs)
|
||||
layer.register_parameter("w2_qweight", w2_qweight)
|
||||
|
||||
w2_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
set_weight_attrs(w2_qweight_type, {
|
||||
"is_gguf_weight_type": True,
|
||||
"weight_type": 0,
|
||||
"ignore_warning": True
|
||||
})
|
||||
|
||||
set_weight_attrs(w2_qweight_type, extra_weight_attrs)
|
||||
layer.register_parameter("w2_qweight_type", w2_qweight_type)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
return None
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `GGUFMoEMethod` yet.")
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
if apply_router_weight_on_input:
|
||||
raise NotImplementedError(
|
||||
"Apply router weight on input is not supported for"
|
||||
"fused GGUF MoE method.")
|
||||
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
|
||||
topk_weights, topk_ids,
|
||||
layer.w13_qweight_type.weight_type,
|
||||
layer.w2_qweight_type.weight_type, activation)
|
||||
|
||||
|
||||
class GGUFEmbeddingMethod(GGUFLinearMethod):
|
||||
"""Embedding method for GGUF.
|
||||
|
||||
Args:
|
||||
quant_config: The GGUF quantization config.
|
||||
"""
|
||||
|
||||
def embedding(self, layer: torch.nn.Module,
|
||||
x: torch.Tensor) -> torch.Tensor:
|
||||
qweight = layer.qweight
|
||||
qweight_type = layer.qweight_type.weight_type
|
||||
hidden_size = qweight.tensor_shape[1]
|
||||
|
||||
return apply_gguf_embedding(x,
|
||||
qweight,
|
||||
qweight_type,
|
||||
hidden_size,
|
||||
dtype=self.params_dtype)
|
||||
|
||||
|
||||
class GGUFUninitializedParameter(UninitializedParameter):
|
||||
cls_to_become = Parameter
|
||||
data_container: list[torch.Tensor]
|
||||
340
vllm/model_executor/layers/quantization/gptq.py
Normal file
340
vllm/model_executor/layers/quantization/gptq.py
Normal file
@@ -0,0 +1,340 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import enum
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
get_linear_quant_method)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter)
|
||||
from vllm.transformers_utils.config import get_safetensors_params_metadata
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
|
||||
class GPTQConfig(QuantizationConfig):
|
||||
"""Config class for GPTQ.
|
||||
|
||||
Reference: https://arxiv.org/abs/2210.17323
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
desc_act: bool,
|
||||
lm_head_quantized: bool,
|
||||
dynamic: dict[str, dict[str, Union[int, bool]]],
|
||||
autoround_version: str = "",
|
||||
modules_in_block_to_quantize: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
# GPTQModel use `dynamic` config property to allow per module
|
||||
# quantization config so each module can be individually optimized.
|
||||
# Format is dict[str, dict] where key is a regex string that can
|
||||
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
|
||||
# matching of a module.
|
||||
# Default to positive match, override base quant config mode, if no
|
||||
# prefix is used. Value is in dict format of field key and override
|
||||
# value.
|
||||
# Negative matching will skip quantization init for this module
|
||||
# entirely:
|
||||
# non-quantized inference. More details and quantization examples can be
|
||||
# found at: https://github.com/ModelCloud/GPTQModel
|
||||
# Example:
|
||||
# # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
|
||||
# # last 1/4 of the layers 16-21 has 8bit and group_size 64
|
||||
# dynamic = {
|
||||
# #`.*\.` matches the layers_node prefix
|
||||
# # positive match layer 10-15
|
||||
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
|
||||
# # positive match layer 16-21
|
||||
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
|
||||
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
|
||||
# }
|
||||
super().__init__()
|
||||
self.dynamic = dynamic
|
||||
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.desc_act = desc_act
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.pack_factor = Fraction(32, self.weight_bits)
|
||||
if self.weight_bits not in [2, 3, 4, 8]:
|
||||
raise ValueError(
|
||||
"Currently, only 2/3/4/8-bit weight quantization is "
|
||||
f"supported for GPTQ, but got {self.weight_bits} bits.")
|
||||
|
||||
self.modules_in_block_to_quantize = modules_in_block_to_quantize or []
|
||||
|
||||
# used to identify GPTQ model quantized by autoround
|
||||
self.autoround_version = autoround_version
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"GPTQConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act}), "
|
||||
f"lm_head_quantized={self.lm_head_quantized}, "
|
||||
f"dynamic={self.dynamic}, "
|
||||
f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "gptq"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
def get_min_capability(cls) -> int:
|
||||
return 60
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "GPTQConfig":
|
||||
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
|
||||
dynamic = {} if dynamic is None else dynamic
|
||||
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
autoround_version = cls.get_from_keys_or(config, ["autoround_version"],
|
||||
default="")
|
||||
modules_in_block_to_quantize = cls.get_from_keys_or(
|
||||
config, ["modules_in_block_to_quantize"], default=None)
|
||||
return cls(weight_bits, group_size, desc_act, lm_head_quantized,
|
||||
dynamic, autoround_version, modules_in_block_to_quantize)
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional[Union["GPTQLinearMethod", "QuantizeMethodBase"]]:
|
||||
if isinstance(layer, FusedMoE):
|
||||
# GPTQ MoE support: fall back to MoeWNA16 for broad compatibility
|
||||
from .moe_wna16 import MoeWNA16Config
|
||||
|
||||
config = {
|
||||
"quant_method": "gptq",
|
||||
"bits": self.weight_bits,
|
||||
"group_size": self.group_size,
|
||||
"sym": True, # GPTQ typically uses symmetric quantization
|
||||
"lm_head": False,
|
||||
}
|
||||
return MoeWNA16Config.from_config(config).get_quant_method(
|
||||
layer, prefix)
|
||||
|
||||
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
|
||||
|
||||
def apply_vllm_mapper(self, hf_to_vllm_mapper):
|
||||
if self.modules_in_block_to_quantize is not None:
|
||||
self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
|
||||
self.modules_in_block_to_quantize)
|
||||
|
||||
def maybe_update_config(self,
|
||||
model_name: str,
|
||||
revision: Optional[str] = None):
|
||||
if self.modules_in_block_to_quantize:
|
||||
if is_list_of(self.modules_in_block_to_quantize, list):
|
||||
# original modules_in_block_to_quantize: list[list[str]]
|
||||
# flatten original modules_in_block_to_quantize
|
||||
self.modules_in_block_to_quantize = [
|
||||
item for sublist in self.modules_in_block_to_quantize
|
||||
for item in sublist
|
||||
]
|
||||
return
|
||||
|
||||
unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
metadata = get_safetensors_params_metadata(model_name,
|
||||
revision=revision)
|
||||
quant_layers: set[str] = {
|
||||
param_name.rsplit(".", 1)[0]
|
||||
for param_name, info in metadata.items()
|
||||
if (dtype := info.get('dtype', None))
|
||||
and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
|
||||
}
|
||||
self.modules_in_block_to_quantize = list(quant_layers)
|
||||
|
||||
|
||||
class ExllamaState(Enum):
|
||||
|
||||
UNUSED = enum.auto()
|
||||
UNINITIALIZED = enum.auto()
|
||||
READY = enum.auto()
|
||||
|
||||
|
||||
class GPTQLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GPTQ.
|
||||
|
||||
Args:
|
||||
quant_config: The GPTQ quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: GPTQConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del output_size # Unused.
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
if input_size_per_partition % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if (output_size_per_partition % self.quant_config.pack_factor.numerator
|
||||
!= 0):
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
exllama_state = ExllamaState.UNINITIALIZED
|
||||
scale_and_zero_size = input_size // group_size
|
||||
scale_and_zero_input_dim = None
|
||||
if (input_size != input_size_per_partition
|
||||
and self.quant_config.group_size != -1):
|
||||
# For act-order models, we cannot use Exllama for row parallel layer
|
||||
if self.quant_config.desc_act:
|
||||
exllama_state = ExllamaState.UNUSED
|
||||
else:
|
||||
# we need to partition qzeros and scales for exllama kernel
|
||||
scale_and_zero_size = input_size_per_partition // group_size
|
||||
scale_and_zero_input_dim = 0
|
||||
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
output_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=0,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
g_idx = RowvLLMParameter(data=torch.tensor(
|
||||
[
|
||||
i // self.quant_config.group_size
|
||||
for i in range(input_size_per_partition)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
qzeros_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
weight_scale_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
if scale_and_zero_input_dim is None:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedColumnParameter(
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(output_dim=1,
|
||||
input_dim=0,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedvLLMParameter(
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("g_idx", g_idx)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
layer.exllama_state = exllama_state
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# for torch.compile
|
||||
layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
|
||||
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
|
||||
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
|
||||
layer.scales = Parameter(layer.scales.data, requires_grad=False)
|
||||
|
||||
# exllama needs to shuffle the weight after the weight is loaded
|
||||
# here we do the shuffle on first forward pass
|
||||
if layer.exllama_state == ExllamaState.UNINITIALIZED:
|
||||
if self.quant_config.desc_act:
|
||||
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
|
||||
else:
|
||||
layer.g_idx.data = torch.empty((0, ),
|
||||
dtype=torch.int,
|
||||
device=layer.g_idx.device)
|
||||
layer.exllama_state = ExllamaState.READY
|
||||
ops.gptq_shuffle(layer.qweight, layer.g_idx,
|
||||
self.quant_config.weight_bits)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
out_shape = x.shape[:-1] + (layer.qweight.shape[-1], )
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
|
||||
layer.scales, layer.g_idx,
|
||||
layer.exllama_state == ExllamaState.READY,
|
||||
self.quant_config.weight_bits)
|
||||
if bias is not None:
|
||||
output.add_(bias)
|
||||
return output.reshape(out_shape)
|
||||
448
vllm/model_executor/layers/quantization/gptq_bitblas.py
Normal file
448
vllm/model_executor/layers/quantization/gptq_bitblas.py
Normal file
@@ -0,0 +1,448 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
BitBLASLinearKernel, MPLinearLayerConfig)
|
||||
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
|
||||
BITBLAS_SUPPORTED_NUM_BITS as GPTQ_BITBLAS_SUPPORTED_NUM_BITS)
|
||||
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
|
||||
BITBLAS_SUPPORTED_SYM as GPTQ_BITBLAS_SUPPORTED_SYM)
|
||||
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
|
||||
MINIMUM_BITBLAS_VERSION, bitblas_repeat_scales_on_all_ranks,
|
||||
check_bitblas_supported, verify_bitblas_supported)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GPTQBitBLASConfig(QuantizationConfig):
|
||||
"""Config class for GPTQ BitBLAS"""
|
||||
|
||||
# (num_bits, is_sym) -> quant_type
|
||||
TYPE_MAP = {
|
||||
(4, True): scalar_types.uint4b8,
|
||||
(8, True): scalar_types.uint8b128,
|
||||
}
|
||||
|
||||
TORCH_DTYPE = torch.float16
|
||||
GPTQ_CKPT_STORAGE_DTYPE = (
|
||||
"int32" # GPTQ Default Checkpoints use int32 as storage dtype
|
||||
)
|
||||
GPTQ_BITBLAS_STORAGE_DTYPE = "int8" # BitBLAS uses int8 as storage dtype
|
||||
TORCH_BITBLAS_STORAGE_DTYPE = getattr(torch, GPTQ_BITBLAS_STORAGE_DTYPE)
|
||||
# "original" or "rescale" or "quantized",
|
||||
# the gptq_bitblas prefer "quantized"
|
||||
ZEROS_MODE = "quantized"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
desc_act: bool,
|
||||
is_sym: bool,
|
||||
quant_method: Optional[str],
|
||||
lm_head_quantized: bool,
|
||||
) -> None:
|
||||
|
||||
try:
|
||||
import bitblas
|
||||
if version.parse(bitblas.__version__) < version.parse(
|
||||
MINIMUM_BITBLAS_VERSION):
|
||||
raise ImportError(
|
||||
"bitblas version is wrong. Please "
|
||||
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
|
||||
except ImportError as e:
|
||||
bitblas_import_exception = e
|
||||
raise ValueError(
|
||||
"Trying to use the bitblas backend, but could not import"
|
||||
f"with the following error: {bitblas_import_exception}. "
|
||||
"Please install bitblas through the following command: "
|
||||
f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
|
||||
) from bitblas_import_exception
|
||||
|
||||
if desc_act and group_size == -1:
|
||||
# In this case, act_order == True is the same as act_order == False
|
||||
# (since we have only one group per output channel)
|
||||
desc_act = False
|
||||
|
||||
super().__init__()
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.desc_act = desc_act
|
||||
self.is_sym = is_sym
|
||||
self.quant_method = quant_method
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
|
||||
# Verify
|
||||
if self.weight_bits not in GPTQ_BITBLAS_SUPPORTED_NUM_BITS:
|
||||
raise ValueError(
|
||||
f"BitBLAS does not support weight_bits = {self.weight_bits}. "
|
||||
f"Only weight_bits = {GPTQ_BITBLAS_SUPPORTED_NUM_BITS} "
|
||||
"are supported.")
|
||||
|
||||
if self.is_sym not in GPTQ_BITBLAS_SUPPORTED_SYM:
|
||||
raise ValueError(
|
||||
f"BitBLAS does not support is_sym = {self.is_sym}. "
|
||||
f"Only sym = {GPTQ_BITBLAS_SUPPORTED_SYM} are supported.")
|
||||
|
||||
self.storage_dtype = self.GPTQ_BITBLAS_STORAGE_DTYPE
|
||||
|
||||
storage_nbit = int("".join(c for c in self.GPTQ_CKPT_STORAGE_DTYPE
|
||||
if c.isdigit()))
|
||||
|
||||
# 4 Bits packed into 32 bit datatype.
|
||||
self.pack_factor = storage_nbit // weight_bits
|
||||
self.nbits = weight_bits
|
||||
|
||||
# Zeros type for the quantized weights.
|
||||
self.zeros_mode = self.ZEROS_MODE
|
||||
|
||||
if (weight_bits, is_sym) not in self.TYPE_MAP:
|
||||
raise ValueError("Unsupported quantization config: "
|
||||
f"bits={weight_bits}, sym={is_sym}")
|
||||
|
||||
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"GPTQBitBLASConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act})"
|
||||
f"is_sym={self.is_sym}, "
|
||||
f"quant_method={self.quant_method})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "gptq_bitblas"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "GPTQBitBLASConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
is_sym = cls.get_from_keys(config, ["sym"])
|
||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
return cls(weight_bits, group_size, desc_act, is_sym, quant_method,
|
||||
lm_head_quantized)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg)
|
||||
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "bitblas"
|
||||
or user_quant == "gptq_bitblas")
|
||||
|
||||
if can_convert and is_valid_user_quant:
|
||||
msg = ("The model is convertible to {} during runtime."
|
||||
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
if can_convert and user_quant == "gptq":
|
||||
logger.info("Detected that the model can run with gptq_bitblas"
|
||||
", however you specified quantization=gptq explicitly,"
|
||||
" so forcing gptq. Use quantization=gptq_bitblas for"
|
||||
" faster inference")
|
||||
return None
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["GPTQBitBLASLinearMethod"]:
|
||||
if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
|
||||
and self.lm_head_quantized):
|
||||
return GPTQBitBLASLinearMethod(self)
|
||||
return None
|
||||
|
||||
@property
|
||||
def torch_storage_dtype(self) -> torch.dtype:
|
||||
return self.TORCH_BITBLAS_STORAGE_DTYPE
|
||||
|
||||
@classmethod
|
||||
def is_gptq_bitblas_compatible(cls, quant_config: dict[str, Any]):
|
||||
# Extract data from quant config.
|
||||
num_bits = quant_config.get("bits")
|
||||
group_size = quant_config.get("group_size")
|
||||
sym = quant_config.get("sym")
|
||||
desc_act = quant_config.get("desc_act")
|
||||
|
||||
# temporarily disable on ROCm platform
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
|
||||
# If we cannot find the info needed in the config, cannot convert.
|
||||
if (num_bits is None or group_size is None or sym is None
|
||||
or desc_act is None):
|
||||
return False
|
||||
|
||||
if (num_bits, sym) not in cls.TYPE_MAP:
|
||||
return False
|
||||
|
||||
# If the capability of the device is too low, cannot convert.
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
device_capability = major * 10 + minor
|
||||
if device_capability < cls.get_min_capability():
|
||||
return False
|
||||
|
||||
# Otherwise, can convert if model satisfies bitblas constraints.
|
||||
return check_bitblas_supported(quant_type=cls.TYPE_MAP[(num_bits,
|
||||
sym)],
|
||||
group_size=group_size)
|
||||
|
||||
|
||||
class GPTQBitBLASLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GPTQ BitBLAS.
|
||||
|
||||
Args:
|
||||
quant_config: The GPTQ BitBLAS quantization config.
|
||||
"""
|
||||
|
||||
kernel_type = BitBLASLinearKernel
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self, quant_config: GPTQBitBLASConfig) -> None:
|
||||
self.quant_config = quant_config
|
||||
# Verify supported on platform.
|
||||
verify_bitblas_supported(quant_type=self.quant_config.quant_type,
|
||||
group_size=self.quant_config.group_size)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
"""Creates quantized weights for use in linear operations.
|
||||
|
||||
The function initializes and returns a dictionary containing
|
||||
quantized weights, scales, and zeros
|
||||
for performing quantized matrix multiplication operations.
|
||||
|
||||
Args:
|
||||
input_size_per_partition: The size of the input partition.
|
||||
output_partition_sizes: The size of the output partition.
|
||||
input_size: The total size of the input (unused).
|
||||
output_size: The total size of the output (unused).
|
||||
params_dtype:
|
||||
The data type of the parameters (expected to be torch.float16).
|
||||
|
||||
Returns:
|
||||
A dictionary containing the quantized weights ('qweight'),
|
||||
scales ('scales'), and zeros ('zeros').
|
||||
|
||||
Raises:
|
||||
ValueError: If `params_dtype` is not `torch.float16` or if the input
|
||||
size per partition is not divisible by the group size
|
||||
in `quant_config`.
|
||||
"""
|
||||
if params_dtype != torch.float16:
|
||||
raise ValueError("Parameter data type must be torch.float16, "
|
||||
f"but got {params_dtype}")
|
||||
|
||||
# Normalize group_size
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
|
||||
if input_size_per_partition % group_size != 0:
|
||||
raise ValueError(
|
||||
f"Input size per partition ({input_size_per_partition}) must "
|
||||
f"be divisible by group size ({self.quant_config.group_size})."
|
||||
)
|
||||
|
||||
kernel_type = self.kernel_type
|
||||
# Validate output_size_per_partition
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
is_row_parallel = input_size != input_size_per_partition
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
|
||||
mp_linear_kernel_config = MPLinearLayerConfig(
|
||||
full_weight_shape=(input_size, output_size),
|
||||
partition_weight_shape=\
|
||||
(input_size_per_partition, output_size_per_partition),
|
||||
weight_type=self.quant_config.quant_type,
|
||||
act_type=params_dtype,
|
||||
group_size=self.quant_config.group_size,
|
||||
zero_points=False,
|
||||
has_g_idx=self.quant_config.desc_act
|
||||
)
|
||||
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for GPTQBitBLASLinearMethod",
|
||||
kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
# Normalize group_size
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
|
||||
# Determine sharding
|
||||
if bitblas_repeat_scales_on_all_ranks(self.quant_config.desc_act,
|
||||
self.quant_config.group_size,
|
||||
is_row_parallel):
|
||||
# By setting scale_dim == None, weight_loader will
|
||||
# repeat the scales on each GPU in TP>1 case.
|
||||
scales_and_zp_input_dim = None
|
||||
scales_and_zp_size = input_size // group_size
|
||||
else:
|
||||
# By setting scale_dim == 0, weight_loader will
|
||||
# shard the scales in TP>1 case.
|
||||
scales_and_zp_input_dim = 0
|
||||
scales_and_zp_size = input_size_per_partition // group_size
|
||||
|
||||
# Init buffers
|
||||
# Quantized weights
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
output_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=0,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
# Activation order
|
||||
# Ignore warning from fused linear layers such as QKVParallelLinear.
|
||||
g_idx = RowvLLMParameter(data=torch.empty(
|
||||
input_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
# Scales
|
||||
scales = Parameter(
|
||||
torch.empty(
|
||||
scales_and_zp_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
scales,
|
||||
{
|
||||
**extra_weight_attrs,
|
||||
"input_dim": scales_and_zp_input_dim,
|
||||
"output_dim": 1,
|
||||
},
|
||||
)
|
||||
|
||||
# Quantized zero-points
|
||||
qzeros_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scales_and_zp_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
weight_scale_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scales_and_zp_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
|
||||
if scales_and_zp_input_dim is None:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedColumnParameter(
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(output_dim=1,
|
||||
input_dim=0,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedvLLMParameter(
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("g_idx", g_idx)
|
||||
layer.register_parameter("scales", scales)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
|
||||
self.kernel = kernel_type(
|
||||
mp_linear_kernel_config,
|
||||
w_q_param_name="qweight",
|
||||
w_s_param_name="scales",
|
||||
w_zp_param_name="qzeros",
|
||||
w_gidx_param_name="g_idx",
|
||||
bitblas_quant_config=self.quant_config,
|
||||
)
|
||||
|
||||
# Initialize or retrieve the BitBLAS matrix multiplication operator.
|
||||
self.kernel.configure_bitblas_matmul(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition,
|
||||
params_dtype=params_dtype,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
out = self.kernel.apply_gptq_bitblas_linear(layer, x)
|
||||
if bias is not None:
|
||||
out.add_(bias)
|
||||
return out
|
||||
751
vllm/model_executor/layers/quantization/gptq_marlin.py
Normal file
751
vllm/model_executor/layers/quantization/gptq_marlin.py
Normal file
@@ -0,0 +1,751 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
||||
FusedMoEQuantConfig)
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
|
||||
UnquantizedFusedMoEMethod)
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
MPLinearLayerConfig, choose_mp_linear_kernel)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
get_dynamic_override, get_linear_quant_method, override_config)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported, check_moe_marlin_supports_layer,
|
||||
marlin_make_workspace_new, marlin_moe_permute_scales, marlin_permute_bias,
|
||||
marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.transformers_utils.config import get_safetensors_params_metadata
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def get_moe_quant_method(
|
||||
config: QuantizationConfig,
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
moe_method_cls: type,
|
||||
):
|
||||
cloned_config = deepcopy(config)
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
# False = skip module, None = no override, else = Positive match
|
||||
if get_dynamic_override( # noqa: E712
|
||||
cloned_config, # noqa: E712
|
||||
layer_name=prefix) == False: # noqa: E712
|
||||
return UnquantizedFusedMoEMethod(layer.moe_config)
|
||||
|
||||
if prefix:
|
||||
# Dynamic per module/layer rules may override base config
|
||||
override_config(cloned_config, prefix=prefix)
|
||||
|
||||
return moe_method_cls(cloned_config, layer.moe_config)
|
||||
return None
|
||||
|
||||
|
||||
class GPTQMarlinConfig(QuantizationConfig):
|
||||
"""Config class for GPTQ Marlin"""
|
||||
|
||||
# (num_bits, is_sym) -> quant_type
|
||||
TYPE_MAP = {
|
||||
(4, True): scalar_types.uint4b8,
|
||||
(8, True): scalar_types.uint8b128,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
desc_act: bool,
|
||||
is_sym: bool,
|
||||
lm_head_quantized: bool,
|
||||
dynamic: dict[str, dict[str, Union[int, bool]]],
|
||||
full_config: dict[str, Any],
|
||||
modules_in_block_to_quantize: Optional[list[str]] = None) -> None:
|
||||
super().__init__()
|
||||
if desc_act and group_size == -1:
|
||||
# In this case, act_order == True is the same as act_order == False
|
||||
# (since we have only one group per output channel)
|
||||
desc_act = False
|
||||
|
||||
# GPTQModel use `dynamic` config property to allow per module
|
||||
# quantization config so each module can be individually optimized.
|
||||
# Format is dict[str, dict] where key is a regex string that can
|
||||
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
|
||||
# matching of a module.
|
||||
# Default to positive match, override base quant config mode, if no
|
||||
# prefix is used. Value is in dict format of field key and override
|
||||
# value.
|
||||
# Negative matching will skip quantization init for this module
|
||||
# entirely:
|
||||
# non-quantized inference. More details and quantization examples can be
|
||||
# found at: https://github.com/ModelCloud/GPTQModel
|
||||
# Example:
|
||||
# # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
|
||||
# # last 1/4 of the layers 16-21 has 8bit and group_size 64
|
||||
# dynamic = {
|
||||
# #`.*\.` matches the layers_node prefix
|
||||
# # positive match layer 10-15
|
||||
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
|
||||
# # positive match layer 16-21
|
||||
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
|
||||
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
|
||||
# }
|
||||
self.dynamic = dynamic
|
||||
|
||||
self.weight_bits = weight_bits
|
||||
self.is_sym = is_sym
|
||||
|
||||
self.pack_factor = 32 // weight_bits # packed into int32
|
||||
self.group_size = group_size
|
||||
self.desc_act = desc_act
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.full_config = full_config
|
||||
|
||||
if (weight_bits, is_sym) not in self.TYPE_MAP:
|
||||
raise ValueError("Unsupported quantization config: "
|
||||
f"bits={weight_bits}, sym={is_sym}")
|
||||
|
||||
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
|
||||
|
||||
self.modules_in_block_to_quantize = modules_in_block_to_quantize or []
|
||||
# used to identify GPTQ model quantized by autoround
|
||||
self.autoround_version = full_config.get("autoround_version", "")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"GPTQMarlinConfig(quant_type={self.quant_type}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act}, "
|
||||
f"lm_head_quantized={self.lm_head_quantized}, "
|
||||
f"dynamic={self.dynamic}, "
|
||||
f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "gptq_marlin"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "GPTQMarlinConfig":
|
||||
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
|
||||
dynamic = {} if dynamic is None else dynamic
|
||||
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
is_sym = cls.get_from_keys(config, ["sym"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
modules_in_block_to_quantize = cls.get_from_keys_or(
|
||||
config, ["modules_in_block_to_quantize"], default=None)
|
||||
return cls(weight_bits, group_size, desc_act, is_sym,
|
||||
lm_head_quantized, dynamic, config,
|
||||
modules_in_block_to_quantize)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
|
||||
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "marlin"
|
||||
or user_quant == "gptq_marlin")
|
||||
|
||||
if can_convert and is_valid_user_quant:
|
||||
msg = ("The model is convertible to {} during runtime."
|
||||
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
if can_convert and user_quant == "gptq":
|
||||
logger.info("Detected that the model can run with gptq_marlin"
|
||||
", however you specified quantization=gptq explicitly,"
|
||||
" so forcing gptq. Use quantization=gptq_marlin for"
|
||||
" faster inference")
|
||||
return None
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if isinstance(layer, FusedMoE):
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import (
|
||||
MoeWNA16Config)
|
||||
if not check_moe_marlin_supports_layer(layer, self.group_size):
|
||||
logger.warning_once(
|
||||
f"Layer '{prefix}' is not supported by GPTQMoeMarlin. "
|
||||
"Falling back to Moe WNA16 kernels.")
|
||||
return MoeWNA16Config.from_config(
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
return get_moe_quant_method(self, layer, prefix,
|
||||
GPTQMarlinMoEMethod)
|
||||
return get_linear_quant_method(self, layer, prefix,
|
||||
GPTQMarlinLinearMethod)
|
||||
|
||||
@classmethod
|
||||
def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]):
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits")
|
||||
group_size = quant_config.get("group_size")
|
||||
sym = quant_config.get("sym")
|
||||
desc_act = quant_config.get("desc_act")
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
|
||||
if quant_method != "gptq":
|
||||
return False
|
||||
|
||||
# Marlin conversion is only valid if required properties are found
|
||||
if (num_bits is None or group_size is None or sym is None
|
||||
or desc_act is None):
|
||||
return False
|
||||
|
||||
if (num_bits, sym) not in cls.TYPE_MAP:
|
||||
return False
|
||||
|
||||
return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
|
||||
group_size=group_size)
|
||||
|
||||
def apply_vllm_mapper(self, hf_to_vllm_mapper):
|
||||
if self.modules_in_block_to_quantize is not None:
|
||||
self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
|
||||
self.modules_in_block_to_quantize)
|
||||
|
||||
def maybe_update_config(self,
|
||||
model_name: str,
|
||||
revision: Optional[str] = None):
|
||||
if self.modules_in_block_to_quantize:
|
||||
if is_list_of(self.modules_in_block_to_quantize, list):
|
||||
# original modules_in_block_to_quantize: list[list[str]]
|
||||
# flatten original modules_in_block_to_quantize
|
||||
self.modules_in_block_to_quantize = [
|
||||
item for sublist in self.modules_in_block_to_quantize
|
||||
for item in sublist
|
||||
]
|
||||
return
|
||||
|
||||
unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
metadata = get_safetensors_params_metadata(model_name,
|
||||
revision=revision)
|
||||
quant_layers: set[str] = {
|
||||
param_name.rsplit(".", 1)[0]
|
||||
for param_name, info in metadata.items()
|
||||
if (dtype := info.get('dtype', None))
|
||||
and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
|
||||
}
|
||||
self.modules_in_block_to_quantize = list(quant_layers)
|
||||
|
||||
|
||||
class GPTQMarlinLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GPTQ Marlin.
|
||||
|
||||
Args:
|
||||
quant_config: The GPTQ Marlin quantization config.
|
||||
"""
|
||||
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
|
||||
self.quant_config = quant_config
|
||||
|
||||
# Verify supported on platform.
|
||||
verify_marlin_supported(quant_type=self.quant_config.quant_type,
|
||||
group_size=self.quant_config.group_size)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
is_row_parallel = input_size != input_size_per_partition
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
|
||||
mp_linear_kernel_config = MPLinearLayerConfig(
|
||||
full_weight_shape=(input_size, output_size),
|
||||
partition_weight_shape=\
|
||||
(input_size_per_partition, output_size_per_partition),
|
||||
weight_type=self.quant_config.quant_type,
|
||||
act_type=params_dtype,
|
||||
group_size=self.quant_config.group_size,
|
||||
zero_points=False,
|
||||
has_g_idx=self.quant_config.desc_act
|
||||
)
|
||||
|
||||
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
|
||||
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for GPTQMarlinLinearMethod",
|
||||
kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
# Normalize group_size
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
|
||||
# Determine sharding
|
||||
if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
|
||||
self.quant_config.group_size,
|
||||
is_row_parallel):
|
||||
# By setting scale_dim == None, weight_loader will
|
||||
# repeat the scales on each GPU in TP>1 case.
|
||||
scales_and_zp_input_dim = None
|
||||
scales_and_zp_size = input_size // group_size
|
||||
else:
|
||||
# By setting scale_dim == 0, weight_loader will
|
||||
# shard the scales in TP>1 case.
|
||||
scales_and_zp_input_dim = 0
|
||||
scales_and_zp_size = input_size_per_partition // group_size
|
||||
|
||||
# Quantized weights
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
output_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=0,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
# Activation order
|
||||
g_idx = RowvLLMParameter(data=torch.empty(
|
||||
input_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
qzeros_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scales_and_zp_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
weight_scale_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scales_and_zp_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
|
||||
if scales_and_zp_input_dim is None:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedColumnParameter(
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(output_dim=1,
|
||||
input_dim=0,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedvLLMParameter(
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("g_idx", g_idx)
|
||||
layer.register_parameter("scales", scales)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
|
||||
self.kernel = kernel_type(mp_linear_kernel_config,
|
||||
w_q_param_name="qweight",
|
||||
w_s_param_name="scales",
|
||||
w_zp_param_name="qzeros",
|
||||
w_gidx_param_name="g_idx")
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
|
||||
|
||||
class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
"""MoE Marlin method with quantization."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: GPTQMarlinConfig,
|
||||
moe: FusedMoEConfig,
|
||||
) -> None:
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
if self.quant_config.quant_type.size_bits == 4:
|
||||
self.quant_type = scalar_types.uint4b8
|
||||
elif self.quant_config.quant_type.size_bits == 8:
|
||||
self.quant_type = scalar_types.uint8b128
|
||||
else:
|
||||
raise ValueError(
|
||||
"GPTQMarlinMoEMethod only supports int4 and int8 now.")
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
intermediate_size_full = extra_weight_attrs.pop(
|
||||
"intermediate_size_full")
|
||||
|
||||
self.is_k_full = (not self.quant_config.desc_act) or (
|
||||
intermediate_size_per_partition == intermediate_size_full)
|
||||
|
||||
if self.quant_config.group_size != -1:
|
||||
scales_size13 = hidden_size // self.quant_config.group_size
|
||||
w2_scales_size = (intermediate_size_full
|
||||
if self.quant_config.desc_act else
|
||||
intermediate_size_per_partition)
|
||||
scales_size2 = (w2_scales_size // self.quant_config.group_size)
|
||||
strategy = FusedMoeWeightScaleSupported.GROUP.value
|
||||
else:
|
||||
scales_size13 = 1
|
||||
scales_size2 = 1
|
||||
strategy = FusedMoeWeightScaleSupported.CHANNEL.value
|
||||
|
||||
extra_weight_attrs.update({
|
||||
"quant_method": strategy,
|
||||
"is_transposed": True
|
||||
})
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_qweight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size // self.quant_config.pack_factor,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_qweight", w13_qweight)
|
||||
set_weight_attrs(w13_qweight, extra_weight_attrs)
|
||||
# down_proj (row parallel)
|
||||
w2_qweight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
intermediate_size_per_partition //
|
||||
self.quant_config.pack_factor,
|
||||
hidden_size,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_qweight", w2_qweight)
|
||||
set_weight_attrs(w2_qweight, extra_weight_attrs)
|
||||
# up_proj scales
|
||||
w13_scales = torch.nn.Parameter(
|
||||
torch.empty(num_experts,
|
||||
scales_size13,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_scales", w13_scales)
|
||||
set_weight_attrs(w13_scales, extra_weight_attrs)
|
||||
# down_proj scales
|
||||
w2_scales = torch.nn.Parameter(
|
||||
torch.empty(num_experts,
|
||||
scales_size2,
|
||||
hidden_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_scales", w2_scales)
|
||||
set_weight_attrs(w2_scales, extra_weight_attrs)
|
||||
# don't shard the w2 scales when running act order
|
||||
set_weight_attrs(w2_scales,
|
||||
{"load_full_w2": self.quant_config.desc_act})
|
||||
# up_proj scales
|
||||
w13_qzeros = torch.nn.Parameter(
|
||||
torch.empty(num_experts,
|
||||
scales_size13,
|
||||
2 * intermediate_size_per_partition //
|
||||
self.quant_config.pack_factor,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_qzeros", w13_qzeros)
|
||||
set_weight_attrs(w13_qzeros, extra_weight_attrs)
|
||||
# down_proj scales
|
||||
w2_qzeros = torch.nn.Parameter(
|
||||
torch.empty(num_experts,
|
||||
scales_size2,
|
||||
hidden_size // self.quant_config.pack_factor,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_qzeros", w2_qzeros)
|
||||
set_weight_attrs(w2_qzeros, extra_weight_attrs)
|
||||
# don't shard the w2 scales when running act order
|
||||
set_weight_attrs(w2_qzeros,
|
||||
{"load_full_w2": self.quant_config.desc_act})
|
||||
w13_g_idx = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_g_idx", w13_g_idx)
|
||||
set_weight_attrs(w13_g_idx, extra_weight_attrs)
|
||||
w2_g_idx = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_g_idx", w2_g_idx)
|
||||
set_weight_attrs(w2_g_idx, extra_weight_attrs)
|
||||
w13_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_g_idx_sort_indices",
|
||||
w13_g_idx_sort_indices)
|
||||
set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
|
||||
w2_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_g_idx_sort_indices",
|
||||
w2_g_idx_sort_indices)
|
||||
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
|
||||
|
||||
device = layer.w13_qweight.device
|
||||
layer.workspace = marlin_make_workspace_new(device, 4)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
|
||||
# Process act_order
|
||||
if self.quant_config.desc_act:
|
||||
# Get sorting based on g_idx
|
||||
num_experts = layer.w13_g_idx.shape[0]
|
||||
w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
|
||||
w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
|
||||
w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
|
||||
w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
|
||||
for e in range(num_experts):
|
||||
w13_g_idx_sort_indices[e] = torch.argsort(
|
||||
layer.w13_g_idx[e]).to(torch.int32)
|
||||
w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
|
||||
torch.int32)
|
||||
w13_sorted_g_idx[e] = layer.w13_g_idx[e][
|
||||
w13_g_idx_sort_indices[e]]
|
||||
w2_sorted_g_idx[e] = layer.w2_g_idx[e][
|
||||
w2_g_idx_sort_indices[e]]
|
||||
replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
|
||||
replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
|
||||
replace_parameter(layer, "w13_g_idx_sort_indices",
|
||||
w13_g_idx_sort_indices)
|
||||
replace_parameter(layer, "w2_g_idx_sort_indices",
|
||||
w2_g_idx_sort_indices)
|
||||
else:
|
||||
# Reset g_idx related tensors
|
||||
num_experts = layer.w13_g_idx.shape[0]
|
||||
device = layer.w13_g_idx.device
|
||||
layer.w13_g_idx = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32,
|
||||
device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w2_g_idx = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32,
|
||||
device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32,
|
||||
device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32,
|
||||
device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
# Repack weights
|
||||
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
|
||||
layer.w13_qweight,
|
||||
layer.w13_g_idx_sort_indices,
|
||||
layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
layer.w13_qweight.shape[2],
|
||||
self.quant_config.quant_type.size_bits,
|
||||
)
|
||||
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
|
||||
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
|
||||
layer.w2_qweight,
|
||||
layer.w2_g_idx_sort_indices,
|
||||
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
layer.w2_qweight.shape[2],
|
||||
self.quant_config.quant_type.size_bits,
|
||||
)
|
||||
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
|
||||
# Repack scales
|
||||
marlin_w13_scales = marlin_moe_permute_scales(
|
||||
s=layer.w13_scales,
|
||||
size_k=layer.intermediate_size_per_partition,
|
||||
size_n=layer.w13_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
)
|
||||
replace_parameter(layer, "w13_scales", marlin_w13_scales)
|
||||
marlin_w2_scales = marlin_moe_permute_scales(
|
||||
s=layer.w2_scales,
|
||||
size_k=layer.w2_scales.shape[1] *
|
||||
(self.quant_config.group_size if self.quant_config.group_size != -1
|
||||
else self.quant_config.pack_factor),
|
||||
size_n=layer.w2_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
)
|
||||
replace_parameter(layer, "w2_scales", marlin_w2_scales)
|
||||
|
||||
if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
|
||||
layer.w13_bias.data = marlin_permute_bias(layer.w13_bias)
|
||||
|
||||
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
|
||||
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
return None
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `GPTQMarlinMoEMethod` yet.")
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
getattr(layer, "w13_bias", None),
|
||||
getattr(layer, "w2_bias", None),
|
||||
layer.w13_scales,
|
||||
layer.w2_scales,
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type_id=self.quant_type.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
g_idx1=layer.w13_g_idx,
|
||||
g_idx2=layer.w2_g_idx,
|
||||
sort_indices1=layer.w13_g_idx_sort_indices,
|
||||
sort_indices2=layer.w2_g_idx_sort_indices,
|
||||
workspace=layer.workspace,
|
||||
is_k_full=self.is_k_full)
|
||||
297
vllm/model_executor/layers/quantization/gptq_marlin_24.py
Normal file
297
vllm/model_executor/layers/quantization/gptq_marlin_24.py
Normal file
@@ -0,0 +1,297 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
GPTQ_MARLIN_24_TILE = 16
|
||||
GPTQ_MARLIN_24_MIN_THREAD_N = 128
|
||||
GPTQ_MARLIN_24_MIN_THREAD_K = 128
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL = 64
|
||||
|
||||
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [
|
||||
scalar_types.uint4b8, scalar_types.uint8b128
|
||||
]
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
|
||||
|
||||
|
||||
class GPTQMarlin24Config(QuantizationConfig):
|
||||
"""Config class for Marlin24.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
quant_type = {
|
||||
4: scalar_types.uint4b8,
|
||||
8: scalar_types.uint8b128,
|
||||
}.get(weight_bits)
|
||||
|
||||
self.group_size = group_size
|
||||
|
||||
# Verify
|
||||
if quant_type is None or \
|
||||
quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES:
|
||||
raise ValueError(
|
||||
f"Marlin_24 does not support quant_type = {quant_type}. "
|
||||
f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} "
|
||||
"are supported.")
|
||||
if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
|
||||
raise ValueError(
|
||||
f"Marlin_24 does not support group_size = {self.group_size}. "
|
||||
f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
|
||||
"are supported.")
|
||||
|
||||
self.quant_type = quant_type
|
||||
|
||||
# 4 Bits packed into 32 bit datatype.
|
||||
self.pack_factor = 32 // self.quant_type.size_bits
|
||||
|
||||
# Tile size used by marlin kernels.
|
||||
self.tile_size = 16
|
||||
|
||||
# Min out_features dim
|
||||
self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N
|
||||
|
||||
# Min in_features dim
|
||||
self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K
|
||||
|
||||
# Max parallel problems to solve at once (improves large
|
||||
# batch performance)
|
||||
self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL
|
||||
|
||||
# Permutation length used by the marlin kernels.
|
||||
self.perm_len = 1024
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "Marlin24Config(quant_type={}, group_size={})".format(
|
||||
self.quant_type, self.group_size)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "gptq_marlin_24"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "GPTQMarlin24Config":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
return cls(weight_bits, group_size)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
is_marlin_24_format = (
|
||||
hf_quant_cfg.get("checkpoint_format") == "marlin_24")
|
||||
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
|
||||
or user_quant == "gptq_marlin_24")
|
||||
|
||||
if is_marlin_24_format and is_valid_user_quant:
|
||||
msg = ("The model is serialized in {} format. "
|
||||
"Using {} kernel.".format(cls.get_name(), cls.get_name()))
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
return None
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["GPTQMarlin24LinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return GPTQMarlin24LinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class GPTQMarlin24LinearMethod(LinearMethodBase):
|
||||
"""Linear method for Marlin24.
|
||||
|
||||
Args:
|
||||
quant_config: The Marlin24 quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: GPTQMarlin24Config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del output_size # Unused.
|
||||
weight_loader = extra_weight_attrs["weight_loader"]
|
||||
if params_dtype != torch.float16:
|
||||
raise ValueError(
|
||||
f"The params dtype must be float16, but got {params_dtype}")
|
||||
|
||||
# Validate output_size_per_partition
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if output_size_per_partition % self.quant_config.min_n_threads != 0:
|
||||
raise ValueError(
|
||||
f"Weight output_size_per_partition = "
|
||||
f"{output_size_per_partition} is not divisible by "
|
||||
f"min_n_threads = {self.quant_config.min_n_threads}.")
|
||||
if output_size_per_partition % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
f"Weight output_size_per_partition = "
|
||||
f"{output_size_per_partition} is not divisible by "
|
||||
f"pack_factor = {self.quant_config.pack_factor}.")
|
||||
|
||||
# Validate input_size_per_partition
|
||||
if input_size_per_partition % self.quant_config.min_k_threads != 0:
|
||||
raise ValueError(
|
||||
f"Weight input_size_per_partition = "
|
||||
f"{input_size_per_partition} is not divisible by "
|
||||
f"min_k_threads = {self.quant_config.min_k_threads}.")
|
||||
if (self.quant_config.group_size != -1 and
|
||||
input_size_per_partition % self.quant_config.group_size != 0):
|
||||
raise ValueError(f"Weight input_size_per_partition = "
|
||||
f"{input_size_per_partition} is not divisible by "
|
||||
f"group_size = {self.quant_config.group_size}.")
|
||||
|
||||
# Check that we have at least 4 tiles horizontally in the shard
|
||||
num_tiles_per_perm = self.quant_config.perm_len // (
|
||||
self.quant_config.tile_size**2)
|
||||
if output_size_per_partition % num_tiles_per_perm != 0:
|
||||
raise ValueError(
|
||||
"Each permutation group must reside on the same gpu")
|
||||
|
||||
# Quantized 4Bit weights packed into Int32.
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.tile_size // 2,
|
||||
output_size_per_partition * self.quant_config.tile_size //
|
||||
self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
marlin_tile_size=self.quant_config.tile_size,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
# Meta
|
||||
meta = PackedvLLMParameter(data=torch.empty(
|
||||
input_size_per_partition // 8 // 2 // 2,
|
||||
output_size_per_partition * 2,
|
||||
device="cuda",
|
||||
dtype=torch.int16,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=1,
|
||||
marlin_tile_size=2,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
# Determine if channelwise or not
|
||||
input_groups = (1 if self.quant_config.group_size == -1 else
|
||||
input_size_per_partition //
|
||||
self.quant_config.group_size)
|
||||
|
||||
weight_scale_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
input_groups,
|
||||
output_size_per_partition,
|
||||
device="cuda",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
if input_groups == 1:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1,
|
||||
**weight_scale_args)
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(output_dim=1,
|
||||
input_dim=0,
|
||||
**weight_scale_args)
|
||||
|
||||
# Allocate workspace (Used for internal locking mechanism)
|
||||
max_workspace_size = (
|
||||
output_size_per_partition //
|
||||
self.quant_config.min_n_threads) * self.quant_config.max_parallel
|
||||
|
||||
workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
|
||||
device="cuda",
|
||||
dtype=torch.int),
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("B_24", qweight)
|
||||
layer.register_parameter("B_meta", meta)
|
||||
layer.register_parameter("s", scales)
|
||||
layer.register_parameter("workspace", workspace)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# required by torch.compile
|
||||
layer.B_24 = Parameter(layer.B_24.data, requires_grad=False)
|
||||
layer.s = Parameter(layer.s.data, requires_grad=False)
|
||||
layer.B_meta = Parameter(layer.B_meta.data, requires_grad=False)
|
||||
layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
qweight = layer.B_24
|
||||
meta = layer.B_meta
|
||||
scales = layer.s
|
||||
workspace = layer.workspace
|
||||
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
size_m = x_2d.shape[0]
|
||||
size_k = x_2d.shape[1]
|
||||
size_n = scales.shape[1]
|
||||
|
||||
output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
|
||||
workspace,
|
||||
self.quant_config.quant_type,
|
||||
size_m, size_n, size_k)
|
||||
|
||||
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output
|
||||
333
vllm/model_executor/layers/quantization/hqq_marlin.py
Normal file
333
vllm/model_executor/layers/quantization/hqq_marlin.py
Normal file
@@ -0,0 +1,333 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
marlin_make_empty_g_idx, marlin_permute_bias, marlin_permute_scales)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
MarlinWorkspace)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import gptq_pack
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class HQQMarlinConfig(QuantizationConfig):
|
||||
"""Config class for HQQ Marlin"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
skip_modules: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert group_size == 64, ("The only supported HQQ group size is "
|
||||
"currently 64.")
|
||||
assert weight_bits == 4, ("The only supported HQQ quantization "
|
||||
"bitsize is currently 4.")
|
||||
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.pack_factor = 32 // weight_bits # packed into int32 in GPTQ format
|
||||
self.quant_type = scalar_types.uint4
|
||||
self.skip_modules = skip_modules
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"HQQMarlinConfig(quant_type={self.quant_type}, "
|
||||
f"group_size={self.group_size})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "hqq"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "HQQMarlinConfig":
|
||||
wq_params = (config["quant_config"]["weight_quant_params"])
|
||||
weight_bits = cls.get_from_keys(wq_params, ["nbits"])
|
||||
group_size = cls.get_from_keys(wq_params, ["group_size"])
|
||||
skip_modules = config["skip_modules"]
|
||||
return cls(weight_bits, group_size, skip_modules)
|
||||
|
||||
def is_layer_skipped(self, prefix: str) -> bool:
|
||||
# Split the prefix into its dot-separated components
|
||||
components = prefix.split('.')
|
||||
|
||||
# Check if any of the skip modules exactly matches any component
|
||||
return self.skip_modules is not None and any(
|
||||
module_name in components for module_name in self.skip_modules)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
if self.is_layer_skipped(prefix):
|
||||
return UnquantizedLinearMethod()
|
||||
return HQQMarlinMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
# Empty HQQ parameter, will be ignored during loading
|
||||
class HQQEmptyParameter(BasevLLMParameter):
|
||||
|
||||
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
|
||||
pass
|
||||
|
||||
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
|
||||
pass
|
||||
|
||||
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
def error_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
|
||||
raise ValueError("No loader provided for HQQ parameter!")
|
||||
|
||||
|
||||
# HQQ packing creates issues with sharding - therefore, prior to loading, we
|
||||
# repack to GPTQ. We also reshape the weights to their proper GPTQ shape.
|
||||
class HQQweightParameter(PackedvLLMParameter):
|
||||
|
||||
# unpack function from https://github.com/mobiusml/hqq
|
||||
def unpack_4bit_u8(self,
|
||||
W_q: torch.Tensor) -> torch.Tensor: # uint8/2 > uint8
|
||||
assert self.weight_bits == 4, "Unsupported quant bitsize (must be 4)"
|
||||
|
||||
dtype = torch.uint8
|
||||
step = W_q.shape[0]
|
||||
tmp = torch.empty([2 * step, W_q.shape[1]],
|
||||
dtype=dtype,
|
||||
device=W_q.device)
|
||||
tmp[:step] = (W_q & 0b11110000) >> 4
|
||||
tmp[step:] = W_q & 0b00001111
|
||||
return tmp
|
||||
|
||||
def __init__(self, packed_factor: int, packed_dim: int, weight_bits: int,
|
||||
**kwargs):
|
||||
super().__init__(packed_factor, packed_dim, None, **kwargs)
|
||||
self.weight_bits = weight_bits
|
||||
self.input_shape = self.shape[self.input_dim] * self.packed_factor
|
||||
self.output_shape = self.shape[self.output_dim]
|
||||
|
||||
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
|
||||
loaded_weight = self.unpack_4bit_u8(loaded_weight)
|
||||
loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose(
|
||||
1, 0)
|
||||
loaded_weight = gptq_pack(loaded_weight, self.weight_bits,
|
||||
loaded_weight.shape[0],
|
||||
loaded_weight.shape[1])
|
||||
super().load_merged_column_weight(loaded_weight, **kwargs)
|
||||
|
||||
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
|
||||
loaded_weight = self.unpack_4bit_u8(loaded_weight)
|
||||
loaded_weight = loaded_weight.reshape(self.output_shape,
|
||||
-1).transpose(1, 0)
|
||||
loaded_weight = gptq_pack(loaded_weight, self.weight_bits,
|
||||
loaded_weight.shape[0],
|
||||
loaded_weight.shape[1])
|
||||
super().load_row_parallel_weight(loaded_weight)
|
||||
|
||||
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
|
||||
loaded_weight = self.unpack_4bit_u8(loaded_weight)
|
||||
loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose(
|
||||
1, 0)
|
||||
loaded_weight = gptq_pack(loaded_weight, self.weight_bits,
|
||||
loaded_weight.shape[0],
|
||||
loaded_weight.shape[1])
|
||||
super().load_qkv_weight(loaded_weight, **kwargs)
|
||||
|
||||
|
||||
# Zero points and scales in HQQ must also be reshaped to correspond to W_q's
|
||||
# GPTQ shape (transposed - we transpose them too when processing weights).
|
||||
class HQQZeroScaleParameter(GroupQuantScaleParameter):
|
||||
|
||||
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
|
||||
loaded_weight = loaded_weight.reshape(-1, self.shape[1])
|
||||
super().load_merged_column_weight(loaded_weight, **kwargs)
|
||||
|
||||
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
|
||||
loaded_weight = loaded_weight.reshape(self.shape[0], -1)
|
||||
super().load_row_parallel_weight(loaded_weight)
|
||||
|
||||
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
|
||||
loaded_weight = loaded_weight.reshape(-1, self.shape[1])
|
||||
super().load_qkv_weight(loaded_weight, **kwargs)
|
||||
|
||||
|
||||
class HQQMarlinMethod(LinearMethodBase):
|
||||
"""Linear method for HQQ Marlin.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: HQQMarlinConfig,
|
||||
):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
self.output_size_per_partition = sum(output_partition_sizes)
|
||||
self.input_size_per_partition = input_size_per_partition
|
||||
|
||||
weight_loader = extra_weight_attrs.get("weight_loader", error_loader)
|
||||
|
||||
self.scales_and_zp_size = (input_size_per_partition //
|
||||
self.quant_config.group_size)
|
||||
|
||||
qweight = HQQweightParameter(
|
||||
data=torch.empty(
|
||||
self.input_size_per_partition // self.quant_config.pack_factor,
|
||||
self.output_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=0,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_bits=self.quant_config.weight_bits,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
zeros = HQQZeroScaleParameter(data=torch.empty(
|
||||
self.output_size_per_partition,
|
||||
self.scales_and_zp_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
scales = HQQZeroScaleParameter(data=torch.empty(
|
||||
self.output_size_per_partition,
|
||||
self.scales_and_zp_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("W_q", qweight)
|
||||
layer.register_parameter("zero", zeros)
|
||||
layer.register_parameter("scale", scales)
|
||||
|
||||
# Ignore extra parameters in the HQQ model.
|
||||
# To be added as needed.
|
||||
ignore_parameters = ("axis", "channel_wise", "compute_dtype",
|
||||
"encoded_state_dict", "group_size", "nbits",
|
||||
"offload_meta", "optimize", "packing",
|
||||
"quant_scale", "quant_zero", "round_zero",
|
||||
"shape", "stores_quant_config",
|
||||
"unpack_view_dtype", "view_as_float")
|
||||
for name in ignore_parameters:
|
||||
layer.register_parameter(
|
||||
name,
|
||||
HQQEmptyParameter(data=torch.empty(0),
|
||||
weight_loader=weight_loader))
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
dev = layer.W_q.device
|
||||
|
||||
# Repack to Marlin
|
||||
sort_indices = torch.empty(0, dtype=torch.int, device=dev)
|
||||
marlin_w_q = ops.gptq_marlin_repack(
|
||||
layer.W_q,
|
||||
sort_indices,
|
||||
self.input_size_per_partition,
|
||||
self.output_size_per_partition,
|
||||
self.quant_config.weight_bits,
|
||||
).to(dev)
|
||||
marlin_s = marlin_permute_scales(layer.scale.transpose(1, 0),
|
||||
self.input_size_per_partition,
|
||||
self.output_size_per_partition,
|
||||
self.quant_config.group_size).to(dev)
|
||||
marlin_zp = marlin_permute_scales(layer.zero.transpose(1, 0),
|
||||
self.input_size_per_partition,
|
||||
self.output_size_per_partition,
|
||||
self.quant_config.group_size).to(dev)
|
||||
|
||||
layer.g_idx = marlin_make_empty_g_idx(dev)
|
||||
layer.g_idx_sort_indices = marlin_make_empty_g_idx(dev)
|
||||
|
||||
layer.marlin_qweight = marlin_w_q
|
||||
layer.marlin_zeros = marlin_zp
|
||||
layer.marlin_scales = marlin_s
|
||||
|
||||
if hasattr(layer, "bias") and layer.bias is not None:
|
||||
layer.bias.data = marlin_permute_bias(layer.bias)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
workspace = MarlinWorkspace(self.output_size_per_partition,
|
||||
GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
|
||||
scales = layer.marlin_scales
|
||||
zeros = layer.marlin_zeros
|
||||
orig_type = x.dtype
|
||||
|
||||
if orig_type != torch.float16:
|
||||
x = x.to(torch.float16)
|
||||
scales = scales.to(torch.float16)
|
||||
zeros = zeros.to(torch.float16)
|
||||
|
||||
marlin_out = ops.gptq_marlin_gemm(
|
||||
x,
|
||||
None,
|
||||
layer.marlin_qweight,
|
||||
bias,
|
||||
scales,
|
||||
None,
|
||||
zeros,
|
||||
layer.g_idx,
|
||||
layer.g_idx_sort_indices,
|
||||
workspace.scratch,
|
||||
scalar_types.uint4,
|
||||
x.shape[0],
|
||||
self.output_size_per_partition,
|
||||
self.input_size_per_partition,
|
||||
True, # is_k_full
|
||||
False, # use atomic add
|
||||
True, # use 32-bit reduce
|
||||
True, # use float zp
|
||||
)
|
||||
|
||||
if orig_type != torch.float16:
|
||||
marlin_out = marlin_out.to(orig_type)
|
||||
|
||||
return marlin_out
|
||||
61
vllm/model_executor/layers/quantization/inc.py
Normal file
61
vllm/model_executor/layers/quantization/inc.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
#
|
||||
# Intel Gaudi supports quantization of various modules and functions,
|
||||
# including, but not limited to `Linear`, `KVCache`, `Matmul` and `Softmax`.
|
||||
# During model loading,
|
||||
# INC will patch layers with quantization/dequantization operators.
|
||||
# Meanwhile, INC will convert original weight to target datatype
|
||||
# and loading to target device.
|
||||
# static scaling should be provided through Quant_CONFIG:
|
||||
# `QUANT_CONFIG` is an environment variable,
|
||||
# that points to the measurement or quantization JSON config file.
|
||||
# The measurement configuration file is used during the calibration procedure,
|
||||
# to collect measurements for a given model.
|
||||
# The quantization configuration is used during inference.
|
||||
# For more information, please refer to:
|
||||
# https://docs.habana.ai/en/v1.21.1/PyTorch/vLLM_Inference/vLLM_FP8_Inference.html
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, UnquantizedFusedMoEMethod)
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
|
||||
|
||||
class INCConfig(QuantizationConfig):
|
||||
"""Config class for FP8 using Intel Neural Compressor."""
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "inc"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "INCConfig":
|
||||
raise AssertionError
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return UnquantizedLinearMethod()
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return UnquantizedFusedMoEMethod(layer.moe_config)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
raise AssertionError
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
return []
|
||||
156
vllm/model_executor/layers/quantization/input_quant_fp8.py
Normal file
156
vllm/model_executor/layers/quantization/input_quant_fp8.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Using the default value (240.0) from pytorch will cause accuracy
|
||||
# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm.
|
||||
_FP8_DTYPE = current_platform.fp8_dtype()
|
||||
_FP8_FINFO = torch.finfo(_FP8_DTYPE)
|
||||
_FP8_MAX = 224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.max
|
||||
_FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min
|
||||
_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)
|
||||
|
||||
|
||||
@CustomOp.register("quant_fp8")
|
||||
class QuantFP8(CustomOp):
|
||||
"""
|
||||
Quantize input tensor to FP8 (per-tensor, per-token, or per-group).
|
||||
This CustomOp supports both static and dynamic quantization.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
static: bool,
|
||||
group_shape: GroupShape,
|
||||
num_token_padding: Optional[int] = None,
|
||||
column_major_scales: bool = False):
|
||||
"""
|
||||
:param static: static or dynamic quantization
|
||||
:param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR,
|
||||
or arbitrary block size)
|
||||
:param num_token_padding: Pad the token dimension of output to this
|
||||
size
|
||||
:param column_major_scales: For group quantization, output scales in
|
||||
column major format
|
||||
"""
|
||||
super().__init__()
|
||||
self.static = static
|
||||
self.group_shape = group_shape
|
||||
self.num_token_padding = num_token_padding
|
||||
self.column_major_scales = column_major_scales
|
||||
|
||||
self.is_group_quant = group_shape.is_per_group()
|
||||
if self.is_group_quant:
|
||||
assert not static, "Group quantization only supports dynamic mode"
|
||||
self.group_size = group_shape.col
|
||||
else:
|
||||
assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR}
|
||||
assert not static or group_shape == GroupShape.PER_TENSOR, \
|
||||
"Only per-tensor scales supported for static quantization."
|
||||
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
scale_ub: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.is_group_quant:
|
||||
assert scale is None, "Group quantization is always dynamic"
|
||||
from vllm.model_executor.layers.quantization.utils import fp8_utils
|
||||
return fp8_utils.per_token_group_quant_fp8(
|
||||
x,
|
||||
group_size=self.group_size,
|
||||
column_major_scales=self.column_major_scales,
|
||||
dtype=_FP8_DTYPE)
|
||||
|
||||
assert (scale is not None) == self.static
|
||||
assert scale_ub is None or (not self.static and self.group_shape
|
||||
== GroupShape.PER_TOKEN
|
||||
and scale_ub.numel() == 1)
|
||||
return ops.scaled_fp8_quant(
|
||||
x,
|
||||
scale,
|
||||
num_token_padding=self.num_token_padding,
|
||||
scale_ub=scale_ub,
|
||||
use_per_token_if_dynamic=self.use_per_token_if_dynamic)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
scale_ub: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.is_group_quant:
|
||||
assert scale is None, "Group quantization is always dynamic"
|
||||
return self._quantize_group_native(x)
|
||||
|
||||
assert (scale is not None) == self.static
|
||||
assert scale_ub is None or (not self.static and self.group_shape
|
||||
== GroupShape.PER_TOKEN
|
||||
and scale_ub.numel() == 1)
|
||||
|
||||
if scale is None:
|
||||
if self.group_shape == GroupShape.PER_TOKEN:
|
||||
x_max, _ = x.abs().max(dim=-1)
|
||||
x_max = x_max.unsqueeze(-1).to(torch.float32)
|
||||
if scale_ub is not None:
|
||||
x_max = x_max.clamp(max=scale_ub)
|
||||
else:
|
||||
x_max = x.abs().max().unsqueeze(-1).to(torch.float32)
|
||||
|
||||
scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
|
||||
|
||||
# Even for dynamic per-token scales,
|
||||
# reciprocal performs slightly better than division
|
||||
out = x.to(torch.float32) * scale.reciprocal()
|
||||
out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
|
||||
|
||||
# This currently generates an extra Triton kernel in compilation.
|
||||
# Fortunately, we don't use padding if compiling.
|
||||
# TODO(luka): benchmark torch._scaled_mm to hopefully remove padding
|
||||
# in general.
|
||||
if self.num_token_padding is not None:
|
||||
padding = max(self.num_token_padding - out.size(0), 0)
|
||||
out = F.pad(out, (0, 0, 0, padding), "constant", 0.0)
|
||||
|
||||
return out, scale
|
||||
|
||||
def _quantize_group_native(
|
||||
self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
orig_shape = x.shape
|
||||
hidden_dim = x.shape[-1]
|
||||
num_groups = (hidden_dim + self.group_size - 1) // self.group_size
|
||||
padded_dim = num_groups * self.group_size
|
||||
|
||||
if padded_dim != hidden_dim:
|
||||
padding = padded_dim - hidden_dim
|
||||
x = F.pad(x, (0, padding), mode='constant', value=0.0)
|
||||
|
||||
x_grouped = x.view(-1, num_groups, self.group_size)
|
||||
absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float()
|
||||
scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
|
||||
|
||||
x_scaled = x_grouped / scales
|
||||
x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
|
||||
|
||||
x_quant = x_quant.view(-1, padded_dim)
|
||||
if padded_dim != hidden_dim:
|
||||
x_quant = x_quant[..., :hidden_dim]
|
||||
x_quant = x_quant.view(orig_shape)
|
||||
|
||||
scales = scales.squeeze(-1)
|
||||
scales = scales.reshape(orig_shape[:-1] + (num_groups, ))
|
||||
|
||||
if self.column_major_scales:
|
||||
scales = scales.transpose(-2, -1).contiguous()
|
||||
|
||||
return x_quant, scales
|
||||
415
vllm/model_executor/layers/quantization/ipex_quant.py
Normal file
415
vllm/model_executor/layers/quantization/ipex_quant.py
Normal file
@@ -0,0 +1,415 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod,
|
||||
is_layer_skipped_awq)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
|
||||
Fp8LinearMethod)
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MIN_IPEX_VERSION = "2.6.0"
|
||||
|
||||
|
||||
class IPEXConfig(QuantizationConfig):
|
||||
"""INT8 quantization config class using IPEX for the CPU/XPU backend,
|
||||
including AWQ, GPTQ.
|
||||
"""
|
||||
|
||||
IPEX_QUANT_METHOD_MAP = {
|
||||
"awq": 1,
|
||||
"gptq": 0,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
method: str,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
modules_to_not_convert: Optional[list[str]] = None,
|
||||
desc_act: Optional[bool] = None,
|
||||
lm_head_quantized: Optional[bool] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.method = method
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.modules_to_not_convert = modules_to_not_convert or []
|
||||
self.desc_act = desc_act
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.pack_factor = 32 // self.weight_bits
|
||||
|
||||
if self.weight_bits not in [4]:
|
||||
raise ValueError(f"IPEX quantization supports weight bits [4], "
|
||||
f"but got {self.weight_bits}.")
|
||||
|
||||
if self.method not in ["awq", "gptq"]:
|
||||
raise ValueError(f"IPEX quantization supports [awq, gptq], "
|
||||
f"but got {self.method}.")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"IPEXConfig(method={self.method},"
|
||||
f"weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "ipex"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.float16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return -1
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
return [
|
||||
"quant_config.json",
|
||||
"quantize_config.json",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "IPEXConfig":
|
||||
method = cls.get_from_keys(config, ["quant_method"]).lower()
|
||||
if method == "awq":
|
||||
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
|
||||
group_size = cls.get_from_keys(config,
|
||||
["q_group_size", "group_size"])
|
||||
modules_to_not_convert = cls.get_from_keys_or(
|
||||
config, ["modules_to_not_convert"], None)
|
||||
return cls(method, weight_bits, group_size, modules_to_not_convert,
|
||||
False, False)
|
||||
# otherwise for gptq
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False)
|
||||
return cls(method, weight_bits, group_size, [], desc_act,
|
||||
lm_head_quantized)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
if not current_platform.is_cpu() and not current_platform.is_xpu():
|
||||
return None
|
||||
|
||||
quant_method = hf_quant_cfg.get("quant_method", "").lower()
|
||||
|
||||
if quant_method in ["awq", "gptq"]:
|
||||
return cls.get_name()
|
||||
|
||||
return None
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["LinearMethodBase"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
if self.method == "awq":
|
||||
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
|
||||
return UnquantizedLinearMethod()
|
||||
return IPEXAWQLinearMethod(self)
|
||||
if self.method == "gptq":
|
||||
return IPEXGPTQLinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class IPEXGPTQLinearMethod(GPTQLinearMethod):
|
||||
"""GPTQ linear method using IPEX for the CPU/XPU backend.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: IPEXConfig):
|
||||
self.quant_config = quant_config # type: ignore
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
bias = layer.bias if not layer.skip_bias_add else None
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if version.parse(
|
||||
ipex.__version__) < version.parse(MIN_IPEX_VERSION):
|
||||
raise ImportError(
|
||||
"intel_extension_for_pytorch version is "
|
||||
"wrong. Please install "
|
||||
f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.")
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"Please install "
|
||||
f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via "
|
||||
f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`"
|
||||
" to use IPEX-AWQ linear method.") from err
|
||||
# Using the compute dtype (lowp_mode) as INT8 to leverage instructions
|
||||
# with better performance.
|
||||
lowp_mode = ipex.quantization.WoqLowpMode.INT8
|
||||
# The weight will be de-packed from INT4 to INT8.
|
||||
weight_dtype = ipex.quantization.WoqWeightDtype.INT4
|
||||
# The float activation will be quantized (dynamic, per-token) to INT8.
|
||||
act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK
|
||||
|
||||
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
|
||||
weight_dtype=weight_dtype,
|
||||
lowp_mode=lowp_mode,
|
||||
act_quant_mode=act_quant_mode,
|
||||
group_size=self.quant_config.group_size,
|
||||
)
|
||||
layer.ipex_output_size = layer.qweight.shape[-1]
|
||||
g_idx = layer.g_idx if self.quant_config.desc_act else None
|
||||
layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \
|
||||
IPEXWeightOnlyQuantizedLinear.from_weight(
|
||||
layer.qweight,
|
||||
layer.scales,
|
||||
layer.qzeros,
|
||||
layer.qweight.size(0),
|
||||
layer.ipex_output_size,
|
||||
qconfig=qconfig,
|
||||
g_idx=g_idx,
|
||||
bias=bias,
|
||||
group_size=self.quant_config.group_size,
|
||||
quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["gptq"]
|
||||
)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out = layer.ipex_qlinear(reshaped_x)
|
||||
return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))
|
||||
|
||||
|
||||
class IPEXAWQLinearMethod(AWQLinearMethod):
|
||||
"""AWQ linear method using IPEX for the CPU/XPU backend.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: IPEXConfig):
|
||||
self.quant_config = quant_config # type: ignore
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
super().process_weights_after_loading(layer=layer)
|
||||
|
||||
bias = layer.bias if not layer.skip_bias_add else None
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if version.parse(
|
||||
ipex.__version__) < version.parse(MIN_IPEX_VERSION):
|
||||
raise ImportError(
|
||||
"intel_extension_for_pytorch version is "
|
||||
"wrong. Please install "
|
||||
f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.")
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"Please install "
|
||||
f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via "
|
||||
f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`"
|
||||
" to use IPEX-AWQ linear method.") from err
|
||||
|
||||
# Using the compute dtype (lowp_mode) as INT8 to leverage instructions
|
||||
# with better performance.
|
||||
lowp_mode = ipex.quantization.WoqLowpMode.INT8
|
||||
# The weight will be de-packed from INT4 to INT8.
|
||||
weight_dtype = ipex.quantization.WoqWeightDtype.INT4
|
||||
# The float activation will be quantized (dynamic, per-token) to INT8.
|
||||
act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH
|
||||
|
||||
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
|
||||
weight_dtype=weight_dtype,
|
||||
lowp_mode=lowp_mode,
|
||||
act_quant_mode=act_quant_mode,
|
||||
group_size=self.quant_config.group_size,
|
||||
)
|
||||
|
||||
layer.ipex_output_size = layer.qweight.size(
|
||||
1) * self.quant_config.pack_factor
|
||||
layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \
|
||||
IPEXWeightOnlyQuantizedLinear.from_weight(
|
||||
layer.qweight,
|
||||
layer.scales,
|
||||
layer.qzeros,
|
||||
layer.qweight.size(0),
|
||||
layer.ipex_output_size,
|
||||
qconfig=qconfig,
|
||||
bias=bias,
|
||||
group_size=self.quant_config.group_size,
|
||||
quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["awq"] # type: ignore
|
||||
)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out = layer.ipex_qlinear(reshaped_x)
|
||||
return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))
|
||||
|
||||
|
||||
class XPUFp8LinearMethod(Fp8LinearMethod):
|
||||
|
||||
def __init__(self, quant_config: Fp8Config):
|
||||
super().__init__(quant_config)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
# If checkpoint not serialized fp8, quantize the weights.
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
|
||||
scale=None)
|
||||
# Update the layer with the new values.
|
||||
layer.weight = Parameter(qweight, requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
layer.input_scale = None
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
weight = layer.weight.data
|
||||
weight_scale = layer.weight_scale.data
|
||||
output = torch.ops.torch_ipex.fp8_gemm_w8a16(x, weight, True,
|
||||
weight_scale, bias)
|
||||
return output
|
||||
|
||||
|
||||
class XPUFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
|
||||
super().__init__(layer.moe_config)
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
|
||||
layer.intermediate_size_per_partition = intermediate_size_per_partition
|
||||
layer.hidden_size = hidden_size
|
||||
layer.num_experts = num_experts
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
# They will be combined to a single scale after weight loading.
|
||||
w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
|
||||
2,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
||||
# INPUT_SCALES
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
w13_weight = torch.empty_like(layer.w13_weight.data,
|
||||
dtype=fp8_dtype)
|
||||
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
|
||||
|
||||
# Re-initialize w13_scale because we directly quantize
|
||||
# merged w13 weights and generate a single scaling factor.
|
||||
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
layer.local_num_experts,
|
||||
dtype=torch.float32,
|
||||
device=w13_weight.device),
|
||||
requires_grad=False)
|
||||
for expert in range(layer.local_num_experts):
|
||||
w13_weight[expert, :, :], layer.w13_weight_scale[
|
||||
expert] = ops.scaled_fp8_quant(
|
||||
layer.w13_weight.data[expert, :, :])
|
||||
w2_weight[expert, :, :], layer.w2_weight_scale[
|
||||
expert] = ops.scaled_fp8_quant(
|
||||
layer.w2_weight.data[expert, :, :])
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight,
|
||||
requires_grad=False)
|
||||
import intel_extension_for_pytorch as ipex
|
||||
layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
w1_scale_inv=layer.w13_weight_scale,
|
||||
w2_scale_inv=layer.w2_weight_scale,
|
||||
a1_scale_inv=layer.w13_input_scale,
|
||||
a2_scale_inv=layer.w2_input_scale,
|
||||
use_prepack=True,
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
return None
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return layer.ipex_fusion(
|
||||
x,
|
||||
use_grouped_topk,
|
||||
top_k,
|
||||
router_logits,
|
||||
renormalize,
|
||||
topk_group,
|
||||
num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
)
|
||||
@@ -0,0 +1,91 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.scalar_type import ScalarType
|
||||
|
||||
|
||||
@dataclass
|
||||
class MPLinearLayerConfig:
|
||||
full_weight_shape: tuple[int, int] # [in, out]
|
||||
partition_weight_shape: tuple[int, int]
|
||||
weight_type: ScalarType
|
||||
act_type: torch.dtype
|
||||
group_size: int
|
||||
zero_points: bool
|
||||
has_g_idx: bool
|
||||
out_type: Optional[torch.dtype] = None
|
||||
|
||||
|
||||
class MPLinearKernel(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(self,
|
||||
c: MPLinearLayerConfig,
|
||||
w_q_param_name: str,
|
||||
w_s_param_name: str,
|
||||
w_zp_param_name: Optional[str] = None,
|
||||
w_gidx_param_name: Optional[str] = None) -> None:
|
||||
assert self.can_implement(c)
|
||||
self.config = c
|
||||
self.w_q_name = w_q_param_name
|
||||
self.w_s_name = w_s_param_name
|
||||
if c.zero_points:
|
||||
assert w_zp_param_name is not None
|
||||
if c.has_g_idx:
|
||||
assert w_gidx_param_name is not None
|
||||
self.w_zp_name = w_zp_param_name
|
||||
self.w_gidx_name = w_gidx_param_name
|
||||
|
||||
@abstractmethod
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def _transform_param(self, layer: torch.nn.Module, name: Optional[str],
|
||||
fn: Callable) -> None:
|
||||
if name is not None and getattr(layer, name, None) is not None:
|
||||
|
||||
old_param = getattr(layer, name)
|
||||
new_param = fn(old_param)
|
||||
# replace the parameter with torch.nn.Parameter for TorchDynamo
|
||||
# compatibility
|
||||
replace_parameter(
|
||||
layer, name,
|
||||
torch.nn.Parameter(new_param.data, requires_grad=False))
|
||||
|
||||
def _get_weight_params(
|
||||
self, layer: torch.nn.Module) -> tuple[
|
||||
torch.Tensor, # w_q
|
||||
torch.Tensor, # w_s
|
||||
Optional[torch.Tensor], # w_zp,
|
||||
Optional[torch.Tensor] # w_gidx
|
||||
]:
|
||||
return (
|
||||
getattr(layer, self.w_q_name),
|
||||
getattr(layer, self.w_s_name),
|
||||
getattr(layer, self.w_zp_name or "", None),
|
||||
getattr(layer, self.w_gidx_name or "", None),
|
||||
)
|
||||
@@ -0,0 +1,93 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501
|
||||
AllSparkLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas import ( # noqa: E501
|
||||
BitBLASLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501
|
||||
ConchLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.cutlass import ( # noqa: E501
|
||||
CutlassW4A8LinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit import ( # noqa: E501
|
||||
Dynamic4bitLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
|
||||
ExllamaLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
|
||||
MacheteLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( # noqa: E501
|
||||
MarlinLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # noqa: E501
|
||||
MPLinearKernel, MPLinearLayerConfig)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
|
||||
CutlassW4A8LinearKernel,
|
||||
MacheteLinearKernel,
|
||||
AllSparkLinearKernel,
|
||||
MarlinLinearKernel,
|
||||
Dynamic4bitLinearKernel,
|
||||
BitBLASLinearKernel,
|
||||
ConchLinearKernel,
|
||||
ExllamaLinearKernel,
|
||||
]
|
||||
|
||||
|
||||
def choose_mp_linear_kernel(
|
||||
config: MPLinearLayerConfig,
|
||||
compute_capability: Optional[int] = None) -> type[MPLinearKernel]:
|
||||
"""
|
||||
Choose an MPLinearKernel that can implement the given config for the given
|
||||
compute capability. Attempts to choose the best kernel in terms of
|
||||
performance.
|
||||
|
||||
Args:
|
||||
config (MPLinearLayerConfig): Description of the linear layer to be
|
||||
implemented.
|
||||
compute_capability (Optional[int], optional): The compute capability of
|
||||
the target device, if None uses `current_platform` to get
|
||||
the compute capability. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: If no kernel can implement the given config.
|
||||
|
||||
Returns:
|
||||
type[MPLinearKernel]: Chosen kernel.
|
||||
"""
|
||||
if compute_capability is None:
|
||||
if current_platform is None:
|
||||
raise ValueError("Cannot determine compute capability")
|
||||
_cc = current_platform.get_device_capability()
|
||||
if _cc is not None:
|
||||
compute_capability = _cc[0] * 10 + _cc[1]
|
||||
|
||||
failure_reasons = []
|
||||
for kernel in _POSSIBLE_KERNELS:
|
||||
if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
|
||||
failure_reasons.append(
|
||||
f' {kernel.__name__} disabled by environment variable')
|
||||
continue
|
||||
if (compute_capability is not None
|
||||
and kernel.get_min_capability() > compute_capability):
|
||||
failure_reasons.append(
|
||||
f"{kernel.__name__} requires capability "
|
||||
f"{kernel.get_min_capability()}, current compute "
|
||||
f" capability is {compute_capability}")
|
||||
continue
|
||||
|
||||
can_implement, failure_reason = kernel.can_implement(config)
|
||||
if can_implement:
|
||||
return kernel
|
||||
else:
|
||||
failure_reasons.append(
|
||||
f' {kernel.__name__} cannot implement due to: {failure_reason}'
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
"Failed to find a kernel that can implement the "\
|
||||
"WNA16 linear layer. Reasons: \n"
|
||||
+ '\n'.join(failure_reasons))
|
||||
@@ -0,0 +1,116 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
|
||||
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, check_allspark_supported_dtype_shape)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
permute_param_layout_)
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
|
||||
class AllSparkLinearKernel(MPLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
if c.has_g_idx:
|
||||
return False, "Act reordering currently not supported by AllSpark"
|
||||
|
||||
if c.zero_points:
|
||||
return False, "Zero points currently not supported by AllSpark"
|
||||
|
||||
return check_allspark_supported_dtype_shape(
|
||||
c.partition_weight_shape[0], # in_features
|
||||
c.partition_weight_shape[1], # out_features
|
||||
c.group_size,
|
||||
c.weight_type,
|
||||
c.act_type)
|
||||
|
||||
# note assumes that
|
||||
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
|
||||
# `weight_scale` is: {input_dim = 0, output_dim = 1}
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
device = getattr(layer, self.w_q_name).device
|
||||
c = self.config
|
||||
|
||||
# prepare the parameters required for the kernel
|
||||
properties = torch.cuda.get_device_properties(device.index)
|
||||
sm_count = properties.multi_processor_count
|
||||
sm_version = properties.major * 10 + properties.minor
|
||||
gemm_args = {}
|
||||
gemm_args['sm_count'] = sm_count
|
||||
gemm_args['sm_version'] = sm_version
|
||||
|
||||
self.gemm_args = gemm_args
|
||||
|
||||
# transform param weight, scale
|
||||
old_weight_param = getattr(layer, self.w_q_name)
|
||||
old_scale_param = getattr(layer, self.w_s_name)
|
||||
|
||||
assert isinstance(old_weight_param, BasevLLMParameter)
|
||||
permute_param_layout_(old_weight_param,
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=0)
|
||||
|
||||
assert isinstance(old_scale_param, BasevLLMParameter)
|
||||
permute_param_layout_(old_scale_param, input_dim=0, output_dim=1)
|
||||
|
||||
# unpack weight from K / 4 x N int32 to K x N uint8
|
||||
new_weight_param = torch.nn.Parameter(old_weight_param.data,
|
||||
requires_grad=False)
|
||||
new_weight_param.data = new_weight_param.data.t().contiguous().view(
|
||||
dtype=torch.uint8)
|
||||
new_weight_param.data = new_weight_param.data.t().contiguous()
|
||||
|
||||
new_scale_param = torch.nn.Parameter(old_scale_param.data,
|
||||
requires_grad=False)
|
||||
|
||||
# reorder K x N weight as N32K16 format for Ampere W8A16
|
||||
new_weight_param.data, new_scale_param.data, _ = \
|
||||
ops.allspark_repack_weight(
|
||||
new_weight_param.data, new_scale_param.data, None,
|
||||
c.zero_points)
|
||||
|
||||
replace_parameter(layer, self.w_q_name, new_weight_param.data)
|
||||
replace_parameter(layer, self.w_s_name, new_scale_param.data)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
c = self.config
|
||||
gemm_args = self.gemm_args
|
||||
w_q, w_s, _, _ = self._get_weight_params(layer)
|
||||
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
|
||||
|
||||
output = ops.allspark_w8a16_gemm(
|
||||
a=reshaped_x,
|
||||
b_qweight=w_q,
|
||||
b_scales=w_s,
|
||||
b_qzeros=None,
|
||||
n=c.partition_weight_shape[1],
|
||||
group_size=c.group_size,
|
||||
sm_count=gemm_args['sm_count'],
|
||||
sm_version=gemm_args['sm_version'],
|
||||
CUBLAS_M_THRESHOLD=ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
|
||||
has_zp=c.zero_points,
|
||||
n32k16_reorder=True)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
@@ -0,0 +1,302 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
|
||||
BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_GROUP_SIZES,
|
||||
MINIMUM_BITBLAS_VERSION, bitblas_make_empty_g_idx, bitblas_sort_g_idx,
|
||||
check_bitblas_supports_shape, query_bitblas_supported_quant_types,
|
||||
unpack_gptq_qweight, unpack_gptq_qzeros)
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BitBLASLinearKernel(MPLinearKernel):
|
||||
|
||||
OPT_FEATURES: list[int] = BITBLAS_OPTIMIZE_FEATURES
|
||||
ENABLE_TUNING: bool = True
|
||||
MATMUL_LAYOUT: str = "nt"
|
||||
BITBLAS_DTYPES: dict[torch.dtype, str] = {
|
||||
torch.float32: "float32",
|
||||
torch.float16: "float16",
|
||||
torch.bfloat16: "bfloat16",
|
||||
torch.half: "float16",
|
||||
torch.int8: "int8",
|
||||
}
|
||||
bitblas_matmul: object = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
c: MPLinearLayerConfig,
|
||||
w_q_param_name: str,
|
||||
w_s_param_name: str,
|
||||
w_zp_param_name: Optional[str] = None,
|
||||
w_gidx_param_name: Optional[str] = None,
|
||||
bitblas_quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
self.quant_config = bitblas_quant_config
|
||||
super().__init__(c, w_q_param_name, w_s_param_name, w_zp_param_name,
|
||||
w_gidx_param_name)
|
||||
|
||||
def repack_bitblas_from_gptq(
|
||||
self,
|
||||
b_q_weight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
qzeros: Optional[torch.Tensor] = None,
|
||||
):
|
||||
from bitblas.quantization.utils import general_compress
|
||||
assert self.bitblas_matmul is not None, "bitblas_matmul is None"
|
||||
|
||||
quant_config = self.quant_config
|
||||
# qweight in gptq old quant linear stored with
|
||||
# (outfeatures, infeatures), should be transposed.
|
||||
qweight = b_q_weight.T.contiguous().view(
|
||||
quant_config.torch_storage_dtype) # type: ignore[union-attr]
|
||||
intweight = unpack_gptq_qweight(
|
||||
qweight,
|
||||
quant_config.weight_bits).contiguous() # type: ignore[union-attr]
|
||||
if self.bitblas_matmul.weight_transform is not None: # type: ignore[attr-defined]
|
||||
qweight = self.bitblas_matmul.weight_transform( # type: ignore[attr-defined]
|
||||
intweight.cpu()).cuda()
|
||||
# scales in gptq old quant linear stored with
|
||||
# (infeatures // group_size, outfeatures), should be transposed.
|
||||
scales = scales.T.contiguous()
|
||||
|
||||
if qzeros is None:
|
||||
return qweight, scales, None
|
||||
|
||||
# qzeros should be de-quantized to int zeros.
|
||||
weight_bits = quant_config.weight_bits # type: ignore[union-attr]
|
||||
intzeros = unpack_gptq_qzeros(qzeros, weight_bits).T.contiguous()
|
||||
zeros: Optional[torch.Tensor] = None
|
||||
zeros_mode = self.bitblas_matmul.config.zeros_mode # type: ignore[attr-defined]
|
||||
if zeros_mode == "original":
|
||||
zeros = intzeros.to(torch.float16).contiguous()
|
||||
elif zeros_mode == "rescale":
|
||||
assert zeros is not None, "zeros should not be None"
|
||||
zeros[:, :] = intzeros.to(torch.float16)[:, :] * scales[:, :]
|
||||
elif zeros_mode == "quantized":
|
||||
zeros = (
|
||||
torch.Tensor(
|
||||
general_compress(
|
||||
intzeros.T.contiguous().cpu().numpy(),
|
||||
weight_bits,
|
||||
)).to(qweight.device).
|
||||
to(quant_config.torch_storage_dtype # type: ignore[union-attr]
|
||||
).contiguous())
|
||||
else:
|
||||
raise ValueError("Unsupported zeros type: {}".format(zeros_mode))
|
||||
|
||||
return qweight, scales, zeros
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
|
||||
is_bitblas_installed = True
|
||||
|
||||
try:
|
||||
import bitblas
|
||||
if version.parse(bitblas.__version__) < version.parse(
|
||||
MINIMUM_BITBLAS_VERSION):
|
||||
raise ImportError(
|
||||
"bitblas version is wrong. Please "
|
||||
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
|
||||
except ImportError:
|
||||
is_bitblas_installed = False
|
||||
|
||||
if not is_bitblas_installed:
|
||||
return False, "bitblas is not installed. Please install bitblas "\
|
||||
"by running `pip install bitblas>="\
|
||||
f"{MINIMUM_BITBLAS_VERSION}`"
|
||||
|
||||
quant_types = query_bitblas_supported_quant_types(c.zero_points)
|
||||
if c.weight_type not in quant_types:
|
||||
return False, (f"Quant type ({c.weight_type}) not supported by"
|
||||
f" BitBLAS, supported types are: {quant_types}")
|
||||
|
||||
if c.group_size not in BITBLAS_SUPPORTED_GROUP_SIZES:
|
||||
return False, (f"Group size ({c.group_size}) not supported by "
|
||||
"BitBLAS, supported group sizes are: "
|
||||
f"{BITBLAS_SUPPORTED_GROUP_SIZES}")
|
||||
|
||||
return check_bitblas_supports_shape(
|
||||
c.partition_weight_shape[1], # out_features
|
||||
c.partition_weight_shape[0], # in_features
|
||||
c.full_weight_shape[0], # in_features
|
||||
c.group_size)
|
||||
|
||||
# note assumes that
|
||||
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
|
||||
# `weight_scale` is: {input_dim = 0, output_dim = 1}
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
device = getattr(layer, self.w_q_name).device
|
||||
c = self.config
|
||||
quant_config = self.quant_config
|
||||
|
||||
# Default names since bitblas requires empty parameters for these,
|
||||
# TODO: remove this requirement from bitblas (allow optional tensors)
|
||||
if self.w_gidx_name is None:
|
||||
self.w_gidx_name = "g_idx"
|
||||
if self.w_zp_name is None:
|
||||
self.w_zp_name = "qzeros"
|
||||
|
||||
if c.has_g_idx:
|
||||
g_idx, g_idx_sort_indices = bitblas_sort_g_idx(
|
||||
getattr(layer, self.w_gidx_name))
|
||||
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
|
||||
layer.g_idx_sort_indices = g_idx_sort_indices
|
||||
else:
|
||||
setattr(layer, self.w_gidx_name, bitblas_make_empty_g_idx(device))
|
||||
layer.g_idx_sort_indices = bitblas_make_empty_g_idx(device)
|
||||
|
||||
if c.zero_points:
|
||||
raise NotImplementedError("Zero points not supported by BitBLAS")
|
||||
else:
|
||||
setattr(layer, self.w_zp_name, bitblas_make_empty_g_idx(device))
|
||||
|
||||
# Repack weights
|
||||
bitblas_qweight, bitblas_scales, bitblas_qzeros = (
|
||||
self.repack_bitblas_from_gptq(
|
||||
layer.qweight,
|
||||
layer.scales,
|
||||
None if quant_config.is_sym else # type: ignore[union-attr]
|
||||
layer.qzeros, # type: ignore[union-attr]
|
||||
))
|
||||
replace_parameter(layer, self.w_q_name, bitblas_qweight)
|
||||
replace_parameter(layer, self.w_s_name, bitblas_scales)
|
||||
if bitblas_qzeros is not None:
|
||||
replace_parameter(layer, self.w_zp_name, bitblas_qzeros)
|
||||
|
||||
def configure_bitblas_matmul(
|
||||
self,
|
||||
infeatures: int,
|
||||
outfeatures: int,
|
||||
params_dtype: torch.dtype,
|
||||
bias: bool,
|
||||
) -> None:
|
||||
enable_tuning = self.ENABLE_TUNING
|
||||
layout = self.MATMUL_LAYOUT
|
||||
bits = self.quant_config.weight_bits # type: ignore[union-attr]
|
||||
self._configure_bitblas_matmul(
|
||||
infeatures,
|
||||
outfeatures,
|
||||
params_dtype,
|
||||
enable_tuning,
|
||||
bias,
|
||||
layout,
|
||||
bits,
|
||||
)
|
||||
|
||||
def _configure_bitblas_matmul(
|
||||
self,
|
||||
infeatures,
|
||||
outfeatures,
|
||||
params_dtype,
|
||||
enable_tuning,
|
||||
bias,
|
||||
layout,
|
||||
bits,
|
||||
):
|
||||
from bitblas import MatmulConfig
|
||||
bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]
|
||||
quant_config = self.quant_config
|
||||
with_scaling = False
|
||||
with_zeros = False
|
||||
group_size = quant_config.group_size # type: ignore[union-attr]
|
||||
zeros_mode = quant_config.zeros_mode # type: ignore[union-attr]
|
||||
if quant_config.quant_method == "gptq": # type: ignore[union-attr]
|
||||
with_scaling = True
|
||||
with_zeros = True
|
||||
W_dtype = f"uint{bits}"
|
||||
if quant_config.is_sym: # type: ignore[union-attr]
|
||||
with_zeros = False
|
||||
W_dtype = f"int{bits}"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported quant_method {quant_config.quant_method}" # type: ignore[union-attr]
|
||||
) # type: ignore[union-attr]
|
||||
|
||||
matmul_config = MatmulConfig(
|
||||
M=self.OPT_FEATURES,
|
||||
N=outfeatures,
|
||||
K=infeatures,
|
||||
A_dtype=bitblas_dtype,
|
||||
W_dtype=W_dtype,
|
||||
out_dtype=bitblas_dtype,
|
||||
accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype,
|
||||
storage_dtype=quant_config. # type: ignore[union-attr]
|
||||
storage_dtype, # type: ignore[union-attr]
|
||||
with_scaling=with_scaling,
|
||||
with_zeros=with_zeros,
|
||||
group_size=group_size,
|
||||
with_bias=bias,
|
||||
layout=layout,
|
||||
zeros_mode=zeros_mode,
|
||||
)
|
||||
self.bitblas_matmul = self._get_or_create_bitblas_operator(
|
||||
matmul_config, enable_tuning)
|
||||
|
||||
def _get_or_create_bitblas_operator(self, config, enable_tuning):
|
||||
from bitblas import Matmul, auto_detect_nvidia_target
|
||||
from bitblas.cache import get_database_path, global_operator_cache
|
||||
BITBLAS_DATABASE_PATH = get_database_path()
|
||||
BITBLAS_TARGET = auto_detect_nvidia_target()
|
||||
|
||||
if global_operator_cache.size() == 0:
|
||||
global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
|
||||
BITBLAS_TARGET)
|
||||
|
||||
bitblas_matmul = global_operator_cache.get(config)
|
||||
if bitblas_matmul is None:
|
||||
bitblas_matmul = Matmul(config,
|
||||
target=BITBLAS_TARGET,
|
||||
enable_tuning=False)
|
||||
if enable_tuning:
|
||||
bitblas_matmul.hardware_aware_finetune(topk=20)
|
||||
global_operator_cache.add(config, bitblas_matmul)
|
||||
global_operator_cache.save_into_database(
|
||||
BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
|
||||
TUNING_MESSAGE = (
|
||||
f"BitBLAS Operator {config} tuned and saved to database.")
|
||||
logger.info(TUNING_MESSAGE)
|
||||
else:
|
||||
_message = f"BitBLAS Operator {config} created without tuning. "
|
||||
logger.info(_message)
|
||||
else:
|
||||
_message = f"BitBLAS Operator {config} retrieved from cache."
|
||||
logger.info(_message)
|
||||
return bitblas_matmul
|
||||
|
||||
def apply_gptq_bitblas_linear(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
output_size_per_partition = self.config.partition_weight_shape[1]
|
||||
out_shape = x.shape[:-1] + (output_size_per_partition, )
|
||||
args = [x, layer.qweight, layer.scales]
|
||||
if self.bitblas_matmul.config.with_zeros: # type: ignore[attr-defined]
|
||||
args.append(layer.qzeros)
|
||||
output = self.bitblas_matmul(*args) # type: ignore[operator]
|
||||
return output.view(out_shape)
|
||||
|
||||
def apply_weights(self, layer, x, bias=None):
|
||||
NOT_IMPLEMENT_MESSAGE = (
|
||||
f"{self.__class__.__name__}.apply_weights is not implemented. "
|
||||
"Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead")
|
||||
raise NotImplementedError(NOT_IMPLEMENT_MESSAGE)
|
||||
@@ -0,0 +1,92 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from importlib.util import find_spec
|
||||
from typing import Final, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
permute_param_layout_)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
_CONCH_SUPPORTED_WEIGHT_TYPES: Final = [
|
||||
scalar_types.uint4, scalar_types.uint8, scalar_types.uint4b8,
|
||||
scalar_types.uint8b128
|
||||
]
|
||||
_CONCH_SUPPORTED_GROUP_SIZES: Final = [-1, 128]
|
||||
|
||||
|
||||
class ConchLinearKernel(MPLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
if c.weight_type not in _CONCH_SUPPORTED_WEIGHT_TYPES:
|
||||
error_msg = f"Weight type ({c.weight_type}) not supported by "\
|
||||
"ConchLinearKernel, supported types are: " \
|
||||
f"{_CONCH_SUPPORTED_WEIGHT_TYPES}"
|
||||
return False, error_msg
|
||||
|
||||
if c.group_size not in _CONCH_SUPPORTED_GROUP_SIZES:
|
||||
error_msg = f"Group size ({c.group_size}) not supported by "\
|
||||
"ConchLinearKernel, supported group sizes are: " \
|
||||
f"{_CONCH_SUPPORTED_GROUP_SIZES}"
|
||||
return False, error_msg
|
||||
|
||||
if find_spec("conch") is None:
|
||||
error_msg = "conch-triton-kernels is not installed, please "\
|
||||
"install it via `pip install conch-triton-kernels` "\
|
||||
"and try again!"
|
||||
return False, error_msg
|
||||
|
||||
return True, None
|
||||
|
||||
# note assumes that
|
||||
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
|
||||
# `weight_scale` is: {input_dim = 0, output_dim = 1}
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
|
||||
def transform_w_q(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
|
||||
x.data = x.data.contiguous()
|
||||
return x
|
||||
|
||||
def transform_w_s(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1)
|
||||
x.data = x.data.contiguous()
|
||||
return x
|
||||
|
||||
self._transform_param(layer, self.w_q_name, transform_w_q)
|
||||
self._transform_param(layer, self.w_s_name, transform_w_s)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
from conch.ops.quantization.gemm import mixed_precision_gemm
|
||||
|
||||
w_q, w_s, w_zp, _ = self._get_weight_params(layer)
|
||||
|
||||
output = mixed_precision_gemm(
|
||||
x=x,
|
||||
w_q_packed=w_q.data,
|
||||
w_s=w_s.data,
|
||||
w_zp=w_zp.data if w_zp is not None else None,
|
||||
weight_size_bits=self.config.weight_type.size_bits,
|
||||
weight_bias=self.config.weight_type.bias,
|
||||
group_size=self.config.group_size,
|
||||
)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output
|
||||
@@ -0,0 +1,117 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
permute_param_layout_)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
|
||||
class CutlassW4A8LinearKernel(MPLinearKernel):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# dynamic per-tok fp8 activation quantization
|
||||
self.quant_fp8 = QuantFP8(static=False,
|
||||
group_shape=GroupShape.PER_TOKEN)
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 90
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
if not current_platform.is_cuda():
|
||||
return False, "CUTLASS only supported on CUDA"
|
||||
|
||||
if not current_platform.is_device_capability(90):
|
||||
return False, "CUTLASS W4A8 requires compute capability of 90 "\
|
||||
"(Hopper)"
|
||||
|
||||
if c.act_type != torch.float8_e4m3fn:
|
||||
return False, "CUTLASS W4A8 only supports FP8 (e4m3) activations"
|
||||
|
||||
if c.has_g_idx:
|
||||
return False, "Act reordering not supported by CUTLASS W4A8"
|
||||
|
||||
if c.zero_points:
|
||||
return False, "Zero points not supported by CUTLASS W4A8"
|
||||
|
||||
if c.weight_type != scalar_types.int4:
|
||||
return False, f"Quant type ({c.weight_type}) not supported by "\
|
||||
"CUTLASS W4A8, only supported int4"
|
||||
|
||||
# TODO(czhu): support -1 (column-wise)
|
||||
if c.group_size != 128:
|
||||
return False, "Only group_size 128 is supported"
|
||||
|
||||
in_features, out_features = c.partition_weight_shape
|
||||
if in_features % 128 or out_features % 128:
|
||||
return False, "K and N must be divisible by 128, got "\
|
||||
f"{c.partition_weight_shape}"
|
||||
|
||||
if c.out_type != torch.bfloat16:
|
||||
return False, "Only bfloat16 output type currently supported"\
|
||||
f"got {c.out_type=}"
|
||||
|
||||
return True, None
|
||||
|
||||
# note assumes that
|
||||
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
|
||||
# `weight_scale` is: {input_dim = 0, output_dim = 1}
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
|
||||
# TODO(czhu): optimize speed/mem usage
|
||||
def transform_w_q(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
|
||||
x.data = ops.cutlass_encode_and_reorder_int4b(
|
||||
x.data.t().contiguous().t())
|
||||
return x
|
||||
|
||||
def transform_w_s(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1)
|
||||
x.data = x.data.contiguous().to(torch.float8_e4m3fn)
|
||||
x.data = ops.cutlass_pack_scale_fp8(x.data)
|
||||
return x
|
||||
|
||||
# Encode/reorder weights and pack scales
|
||||
self._transform_param(layer, self.w_q_name, transform_w_q)
|
||||
self._transform_param(layer, self.w_s_name, transform_w_s)
|
||||
self._transform_param(layer, "weight_chan_scale", lambda x: x)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
c = self.config
|
||||
w_q, w_s, _, _ = self._get_weight_params(layer)
|
||||
w_ch_s = layer.weight_chan_scale
|
||||
|
||||
x_2d = x.reshape(-1, x.shape[-1])
|
||||
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
|
||||
|
||||
x_2d, act_scales = self.quant_fp8(x_2d)
|
||||
output = ops.cutlass_w4a8_mm(a=x_2d,
|
||||
b_q=w_q,
|
||||
b_group_scales=w_s,
|
||||
b_group_size=c.group_size,
|
||||
a_token_scales=act_scales,
|
||||
b_channel_scales=w_ch_s)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
@@ -0,0 +1,92 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
|
||||
class Dynamic4bitLinearKernel(MPLinearKernel):
|
||||
SUPPORTED_QUANT_TYPES = [scalar_types.int4]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 1
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
if not current_platform.is_cpu():
|
||||
return False, "Only CPU is supported"
|
||||
if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
|
||||
return False, f"Unsupported quant type {c.weight_type}"
|
||||
if current_platform.get_cpu_architecture(
|
||||
) == CpuArchEnum.ARM and c.act_type not in [
|
||||
torch.float32,
|
||||
]:
|
||||
return False, "Dynamic4bitLinearKernel on Arm requires"\
|
||||
" Float32 activations"
|
||||
if c.full_weight_shape[0] % c.group_size != 0:
|
||||
return False, f"Group size ({c.group_size}) does not evenly divide"\
|
||||
" the number of input features "\
|
||||
f"({c.full_weight_shape[0]})"
|
||||
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
|
||||
try:
|
||||
# Attempt to retrieve the operation
|
||||
_ = torch.ops.aten._dyn_quant_matmul_4bit
|
||||
except AttributeError:
|
||||
return False, f"PyTorch {torch.__version__} does not support"\
|
||||
" _dyn_quant_matmul_4bit. Install a newer version"
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
c = self.config
|
||||
packed_weight = getattr(layer, self.w_q_name)
|
||||
packed_weight = packed_weight.add(8)
|
||||
uint8_packed = (packed_weight[::, 1::2] << 4
|
||||
| packed_weight[::, ::2]).to(torch.uint8)
|
||||
|
||||
scales = getattr(layer, self.w_s_name)
|
||||
block_size = c.group_size
|
||||
|
||||
# Handle scaling factors for partitioned weights
|
||||
if block_size == c.partition_weight_shape[0]:
|
||||
scales = scales.to(
|
||||
torch.float32
|
||||
) # Float32 & Bfloat16 variants requires float32 scales
|
||||
scales = scales.view(-1, 1) # Channel-wise scales
|
||||
if layer.bias is not None:
|
||||
layer.bias = layer.bias.to(
|
||||
torch.float32
|
||||
) # Float32 & Bfloat16 variants requires float32 bias
|
||||
else:
|
||||
# KleidiAI kernel requires bfloat16 scales with groupwise scheme
|
||||
scales = scales.to(torch.bfloat16)
|
||||
|
||||
# Repack weights as per kernel requirement
|
||||
w = torch.ops.aten._dyn_quant_pack_4bit_weight(
|
||||
uint8_packed, scales, layer.bias, block_size,
|
||||
c.partition_weight_shape[0], c.partition_weight_shape[1])
|
||||
replace_parameter(layer, self.w_q_name,
|
||||
torch.nn.Parameter(w, requires_grad=False))
|
||||
setattr(layer, self.w_s_name, None)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
c = self.config
|
||||
x_2d = x.reshape(-1, x.shape[-1])
|
||||
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
|
||||
|
||||
w_q = getattr(layer, self.w_q_name)
|
||||
output = torch.ops.aten._dyn_quant_matmul_4bit(
|
||||
x_2d, w_q, c.group_size, c.partition_weight_shape[0],
|
||||
c.partition_weight_shape[1])
|
||||
return output.reshape(out_shape)
|
||||
@@ -0,0 +1,143 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
pack_quantized_values_into_int32)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
permute_param_layout_)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
|
||||
class ExllamaLinearKernel(MPLinearKernel):
|
||||
SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
# In theory supports `scalar_types.uint2b2, scalar_types.uint3b4` too but
|
||||
# currently untested so not added to the list
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 60
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
if c.has_g_idx and\
|
||||
c.partition_weight_shape[0] != c.full_weight_shape[0]:
|
||||
return False, "Act reordering currently not supported by Exllama, "\
|
||||
"when the input features are partitioned across "\
|
||||
"devices"
|
||||
|
||||
if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0:
|
||||
return False, "Output features must be a multiple of the pack " \
|
||||
"factor (32 / num_bits) so that we can correctly " \
|
||||
"pack the zero points"
|
||||
|
||||
if c.act_type != torch.float16:
|
||||
return False, "Exllama only supports float16 activations"
|
||||
|
||||
if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
|
||||
return False, f"Quant type ({c.weight_type}) not supported by "\
|
||||
"Exllama, supported types are: "\
|
||||
f"{cls.SUPPORTED_QUANT_TYPES}"
|
||||
|
||||
if c.full_weight_shape[0] % c.group_size != 0:
|
||||
return False, f"Group size ({c.group_size}) does not evenly divide"\
|
||||
" the number of input features "\
|
||||
f"({c.full_weight_shape[0]})"
|
||||
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
c = self.config
|
||||
|
||||
# For Exllama, we need to set a zero-point tensor if there is not one
|
||||
if not c.zero_points:
|
||||
self.w_zp_name = "qzeros"
|
||||
device = getattr(layer, self.w_q_name).device
|
||||
groups = c.partition_weight_shape[0] // c.group_size
|
||||
out_features = c.partition_weight_shape[1]
|
||||
|
||||
if c.weight_type.has_bias():
|
||||
# if the type has a bias we have to create a zeros tensor that
|
||||
# contains the bias values repeated for each group (-1 due to
|
||||
# a bug in the original GPTQ checkpoint format leading to
|
||||
# exllama kernel adding 1 to the zero points during inference)
|
||||
# Documentation of the bug can be found here:
|
||||
# https://garden.danieldk.eu/GPTQ-Checkpoint-Format
|
||||
zeros = torch.full((groups, out_features),
|
||||
c.weight_type.bias - 1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"A 0 zero-point is not supported by Exllama due to "
|
||||
"a bug in the original GPTQ checkpoint format leading to "
|
||||
"exllama kernel adding 1 to the zero points during "
|
||||
"inference")
|
||||
zeros = pack_quantized_values_into_int32(zeros,
|
||||
c.weight_type,
|
||||
packed_dim=1)
|
||||
setattr(layer, self.w_zp_name,
|
||||
torch.nn.Parameter(zeros, requires_grad=False))
|
||||
|
||||
if c.has_g_idx:
|
||||
|
||||
def transform_w_g_idx(x):
|
||||
# Exllama wants the permutation array instead of the group
|
||||
# indices
|
||||
return torch.argsort(x).to(torch.int)
|
||||
|
||||
self._transform_param(layer, self.w_gidx_name, transform_w_g_idx)
|
||||
else:
|
||||
self.w_gidx_name = "g_idx"
|
||||
empty_g_idx = torch.nn.Parameter(torch.empty((0, ),
|
||||
dtype=torch.int,
|
||||
device=device),
|
||||
requires_grad=False)
|
||||
setattr(layer, self.w_gidx_name, empty_g_idx)
|
||||
|
||||
def transform_w_q(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
assert self.w_gidx_name is not None
|
||||
g_idx = getattr(layer, self.w_gidx_name)
|
||||
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
|
||||
x_cont = x.data.contiguous()
|
||||
ops.gptq_shuffle(x_cont, g_idx, c.weight_type.size_bits)
|
||||
return x_cont
|
||||
|
||||
def transform_w_s(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1)
|
||||
x.data = x.data.contiguous()
|
||||
return x.to(dtype=c.act_type)
|
||||
|
||||
# Repack weights and scales for Machete
|
||||
self._transform_param(layer, self.w_q_name, transform_w_q)
|
||||
self._transform_param(layer, self.w_s_name, transform_w_s)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
c = self.config
|
||||
|
||||
x_2d = x.reshape(-1, x.shape[-1])
|
||||
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
|
||||
|
||||
w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer)
|
||||
|
||||
assert w_zp is not None, "Zero points are required by Exllama"
|
||||
assert w_g_idx is not None, "Group index is required by Exllama"
|
||||
output = ops.gptq_gemm(x_2d, w_q, w_zp, w_s, w_g_idx, True,
|
||||
c.weight_type.size_bits)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias)
|
||||
return output.reshape(out_shape)
|
||||
@@ -0,0 +1,144 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.machete_utils import (
|
||||
check_machete_supports_shape, query_machete_supported_group_sizes,
|
||||
query_machete_supported_quant_types)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
permute_param_layout_)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
|
||||
class MacheteLinearKernel(MPLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 90
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
# Machete uses CUTLASS, so it can only be compatible with Nvidia
|
||||
if not current_platform.is_cuda():
|
||||
return False, "Machete only supported on CUDA"
|
||||
|
||||
if not current_platform.is_device_capability(90):
|
||||
return False, "Machete requires compute capability of 90 (Hopper)"
|
||||
|
||||
if c.has_g_idx and\
|
||||
c.partition_weight_shape[0] != c.full_weight_shape[0]:
|
||||
return False, "Act reordering currently not supported by Machete, "\
|
||||
"when the input features are partitioned across "\
|
||||
"devices"
|
||||
|
||||
if c.weight_type not in query_machete_supported_quant_types(
|
||||
c.zero_points):
|
||||
return False, f"Quant type ({c.weight_type}) not supported by "\
|
||||
"Machete, supported types are: "\
|
||||
f"{query_machete_supported_quant_types(c.zero_points)}"
|
||||
|
||||
if c.group_size not in query_machete_supported_group_sizes(c.act_type):
|
||||
return False, f"Group size ({c.group_size}) not supported by "\
|
||||
"Machete, supported group sizes are: "\
|
||||
f"{query_machete_supported_group_sizes(c.act_type)}"
|
||||
|
||||
return check_machete_supports_shape(c.partition_weight_shape[0],
|
||||
c.partition_weight_shape[1])
|
||||
|
||||
# note assumes that
|
||||
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
|
||||
# `weight_scale` is: {input_dim = 0, output_dim = 1}
|
||||
# `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1}
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
c = self.config
|
||||
|
||||
if c.has_g_idx:
|
||||
assert self.w_gidx_name is not None
|
||||
perm = torch.argsort(getattr(layer, self.w_gidx_name))\
|
||||
.to(torch.int)
|
||||
|
||||
self.act_perm = lambda x: x[:, perm]
|
||||
# use `ops.permute_cols` if possible
|
||||
if c.act_type in [torch.float16, torch.bfloat16] \
|
||||
and c.partition_weight_shape[0] % 8 == 0:
|
||||
self.act_perm = partial(ops.permute_cols, perm=perm)
|
||||
|
||||
def transform_w_q(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
|
||||
if c.has_g_idx:
|
||||
x_unpacked = unpack_quantized_values_into_int32(x.data,
|
||||
c.weight_type,
|
||||
packed_dim=0)
|
||||
x_perm = x_unpacked[perm, :]
|
||||
x.data = pack_quantized_values_into_int32(x_perm,
|
||||
c.weight_type,
|
||||
packed_dim=0)
|
||||
x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
|
||||
a_type=c.act_type,
|
||||
b_type=c.weight_type,
|
||||
group_scales_type=c.act_type)
|
||||
return x
|
||||
|
||||
def transform_w_s(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1)
|
||||
x.data = x.data.contiguous()
|
||||
return x
|
||||
|
||||
def transform_w_zp(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=1)
|
||||
x_unpacked = unpack_quantized_values_into_int32(x.data,
|
||||
c.weight_type,
|
||||
packed_dim=1)
|
||||
w_s = getattr(layer, self.w_s_name).data
|
||||
# pre-apply scales to zero-points
|
||||
x.data = (-1.0 * w_s * (x_unpacked.to(w_s.dtype))).contiguous()
|
||||
return x
|
||||
|
||||
# Repack weights and scales for Machete
|
||||
self._transform_param(layer, self.w_q_name, transform_w_q)
|
||||
self._transform_param(layer, self.w_s_name, transform_w_s)
|
||||
if c.zero_points:
|
||||
self._transform_param(layer, self.w_zp_name, transform_w_zp)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
c = self.config
|
||||
w_q, w_s, w_zp, _ = self._get_weight_params(layer)
|
||||
|
||||
x_2d = x.reshape(-1, x.shape[-1])
|
||||
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
|
||||
|
||||
if c.has_g_idx:
|
||||
x_2d = self.act_perm(x_2d)
|
||||
|
||||
if c.zero_points:
|
||||
assert w_zp is not None
|
||||
else:
|
||||
w_zp = None
|
||||
|
||||
output = ops.machete_mm(a=x_2d,
|
||||
b_q=w_q,
|
||||
b_type=c.weight_type,
|
||||
b_group_zeros=w_zp,
|
||||
b_group_scales=w_s,
|
||||
b_group_size=c.group_size)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
@@ -0,0 +1,139 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
|
||||
check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
|
||||
marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales,
|
||||
marlin_sort_g_idx, marlin_zero_points, query_marlin_supported_quant_types,
|
||||
unpack_cols)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
permute_param_layout_)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
|
||||
class MarlinLinearKernel(MPLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
# Marlin uses inline PTX, so it can only be compatible with Nvidia
|
||||
if not current_platform.is_cuda():
|
||||
return False, "Marlin only supported on CUDA"
|
||||
|
||||
quant_types = query_marlin_supported_quant_types(c.zero_points)
|
||||
if c.weight_type not in quant_types:
|
||||
return False, f"Quant type ({c.weight_type}) not supported by"\
|
||||
f" Marlin, supported types are: {quant_types}"
|
||||
|
||||
if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
|
||||
return False, f"Group size ({c.group_size}) not supported by "\
|
||||
"Marlin, supported group sizes are: "\
|
||||
f"{MARLIN_SUPPORTED_GROUP_SIZES}"
|
||||
|
||||
return check_marlin_supports_shape(
|
||||
c.partition_weight_shape[1], # out_features
|
||||
c.partition_weight_shape[0], # in_features
|
||||
c.full_weight_shape[0], # in_features
|
||||
c.group_size)
|
||||
|
||||
# note assumes that
|
||||
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
|
||||
# `weight_scale` is: {input_dim = 0, output_dim = 1}
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
device = getattr(layer, self.w_q_name).device
|
||||
c = self.config
|
||||
|
||||
row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0])
|
||||
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
|
||||
|
||||
# Allocate marlin workspace.
|
||||
self.workspace = marlin_make_workspace_new(device)
|
||||
|
||||
# Default names since marlin requires empty parameters for these,
|
||||
# TODO: remove this requirement from marlin (allow optional tensors)
|
||||
if self.w_gidx_name is None:
|
||||
self.w_gidx_name = "g_idx"
|
||||
if self.w_zp_name is None:
|
||||
self.w_zp_name = "w_zp"
|
||||
|
||||
def transform_w_q(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
|
||||
x.data = ops.gptq_marlin_repack(x.data.contiguous(),
|
||||
perm=layer.g_idx_sort_indices,
|
||||
size_k=c.partition_weight_shape[0],
|
||||
size_n=c.partition_weight_shape[1],
|
||||
num_bits=c.weight_type.size_bits)
|
||||
return x
|
||||
|
||||
def transform_w_s(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1)
|
||||
x.data = marlin_permute_scales(x.data.contiguous(),
|
||||
size_k=c.partition_weight_shape[0],
|
||||
size_n=c.partition_weight_shape[1],
|
||||
group_size=c.group_size)
|
||||
return x
|
||||
|
||||
if c.has_g_idx:
|
||||
g_idx, g_idx_sort_indices = marlin_sort_g_idx(
|
||||
getattr(layer, self.w_gidx_name))
|
||||
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
|
||||
layer.g_idx_sort_indices = g_idx_sort_indices
|
||||
else:
|
||||
setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
|
||||
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
|
||||
|
||||
if c.zero_points:
|
||||
grouped_k = (c.partition_weight_shape[0] //
|
||||
c.group_size if c.group_size != -1 else 1)
|
||||
self._transform_param(layer, self.w_zp_name, lambda x: \
|
||||
marlin_zero_points(
|
||||
unpack_cols(x.t(), c.weight_type.size_bits,
|
||||
grouped_k,
|
||||
c.partition_weight_shape[1]),
|
||||
size_k=grouped_k,
|
||||
size_n=c.partition_weight_shape[1],
|
||||
num_bits=c.weight_type.size_bits))
|
||||
else:
|
||||
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
|
||||
self._transform_param(layer, self.w_q_name, transform_w_q)
|
||||
self._transform_param(layer, self.w_s_name, transform_w_s)
|
||||
|
||||
if hasattr(layer, "bias") and layer.bias is not None:
|
||||
layer.bias.data = marlin_permute_bias(layer.bias)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
c = self.config
|
||||
w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)
|
||||
|
||||
# `process_weights_after_loading` will ensure w_zp and w_gidx are not
|
||||
# None for marlin
|
||||
return apply_gptq_marlin_linear(
|
||||
input=x,
|
||||
weight=w_q,
|
||||
weight_scale=w_s,
|
||||
weight_zp=w_zp, # type: ignore
|
||||
g_idx=w_gidx, # type: ignore
|
||||
g_idx_sort_indices=layer.g_idx_sort_indices,
|
||||
workspace=self.workspace,
|
||||
wtype=c.weight_type,
|
||||
input_size_per_partition=c.partition_weight_shape[0],
|
||||
output_size_per_partition=c.partition_weight_shape[1],
|
||||
is_k_full=self.is_k_full,
|
||||
bias=bias)
|
||||
@@ -0,0 +1,67 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScaledMMLinearLayerConfig:
|
||||
is_channelwise: bool
|
||||
is_static_input_scheme: bool
|
||||
input_symmetric: bool
|
||||
|
||||
|
||||
class ScaledMMLinearKernel(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def can_implement(
|
||||
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(self, c: ScaledMMLinearLayerConfig, w_q_param_name: str,
|
||||
w_s_param_name: str, i_s_param_name: str,
|
||||
i_zp_param_name: str, azp_adj_param_name: str) -> None:
|
||||
assert self.can_implement(c)
|
||||
self.config = c
|
||||
self.w_q_name = w_q_param_name
|
||||
self.w_s_name = w_s_param_name
|
||||
self.i_s_name = i_s_param_name
|
||||
self.i_zp_name = i_zp_param_name
|
||||
self.azp_adj_name = azp_adj_param_name
|
||||
|
||||
@abstractmethod
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_weight_params(
|
||||
self, layer: torch.nn.Module) -> tuple[
|
||||
torch.Tensor, # weight
|
||||
torch.Tensor, # weight_scale
|
||||
Optional[torch.Tensor], # input_scale,
|
||||
Optional[torch.Tensor], # input_zp
|
||||
Optional[torch.Tensor], # azp_adj
|
||||
]:
|
||||
return (
|
||||
getattr(layer, self.w_q_name),
|
||||
getattr(layer, self.w_s_name),
|
||||
getattr(layer, self.i_s_name),
|
||||
getattr(layer, self.i_zp_name),
|
||||
getattr(layer, self.azp_adj_name),
|
||||
)
|
||||
@@ -0,0 +1,89 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
|
||||
AiterScaledMMLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
|
||||
CPUScaledMMLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
|
||||
CutlassScaledMMLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearKernel, ScaledMMLinearLayerConfig)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
|
||||
TritonScaledMMLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
|
||||
XLAScaledMMLinearKernel)
|
||||
from vllm.platforms import PlatformEnum, current_platform
|
||||
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
|
||||
PlatformEnum.CPU: [CPUScaledMMLinearKernel],
|
||||
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
|
||||
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
|
||||
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
|
||||
}
|
||||
|
||||
|
||||
def choose_scaled_mm_linear_kernel(
|
||||
config: ScaledMMLinearLayerConfig,
|
||||
compute_capability: Optional[int] = None
|
||||
) -> type[ScaledMMLinearKernel]:
|
||||
"""
|
||||
Choose an ScaledMMLinearKernel that can implement the given config for the
|
||||
given compute capability. Attempts to choose the best kernel in terms of
|
||||
performance.
|
||||
|
||||
Args:
|
||||
config (ScaledMMLinearLayerConfig): Description of the linear layer
|
||||
to be implemented.
|
||||
compute_capability (Optional[int], optional): The compute capability of
|
||||
the target device, if None uses `current_platform` to get the
|
||||
compute capability. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: If no kernel can implement the given config.
|
||||
|
||||
Returns:
|
||||
type[ScaledMMLinearKernel]: Chosen kernel.
|
||||
"""
|
||||
|
||||
if compute_capability is None:
|
||||
_cc = current_platform.get_device_capability()
|
||||
if _cc is not None:
|
||||
compute_capability = _cc[0] * 10 + _cc[1]
|
||||
|
||||
failure_reasons = []
|
||||
for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
|
||||
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\
|
||||
.split(","):
|
||||
failure_reasons.append(
|
||||
f' {kernel.__name__} disabled by environment variable')
|
||||
continue
|
||||
|
||||
# If the current platform uses compute_capability,
|
||||
# make sure the kernel supports the compute cability.
|
||||
if compute_capability is not None:
|
||||
kernel_min_capability = kernel.get_min_capability()
|
||||
if (kernel_min_capability is not None
|
||||
and kernel_min_capability > compute_capability):
|
||||
failure_reasons.append(
|
||||
f"{kernel.__name__} requires capability "
|
||||
f"{kernel_min_capability}, current compute capability "
|
||||
f"is {compute_capability}")
|
||||
continue
|
||||
|
||||
can_implement, failure_reason = kernel.can_implement(config)
|
||||
if can_implement:
|
||||
return kernel
|
||||
else:
|
||||
failure_reasons.append(
|
||||
f' {kernel.__name__} cannot implement due to: {failure_reason}'
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
"Failed to find a kernel that can implement the "\
|
||||
"ScaledMM linear layer. Reasons: \n"
|
||||
+ '\n'.join(failure_reasons))
|
||||
@@ -0,0 +1,161 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
from .cutlass import CutlassScaledMMLinearKernel
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
|
||||
|
||||
|
||||
def rocm_aiter_gemm_w8a8_impl(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
|
||||
from aiter import gemm_a8w8_CK
|
||||
|
||||
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
|
||||
# a to be [M, K]
|
||||
# b to be [N, K]
|
||||
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
|
||||
return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
|
||||
|
||||
|
||||
def rocm_aiter_gemm_w8a8_fake(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
|
||||
m = A.shape[0]
|
||||
n = B.shape[0]
|
||||
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
|
||||
return Y
|
||||
|
||||
|
||||
if current_platform.is_rocm():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_gemm_w8a8",
|
||||
op_func=rocm_aiter_gemm_w8a8_impl,
|
||||
fake_impl=rocm_aiter_gemm_w8a8_fake,
|
||||
)
|
||||
|
||||
|
||||
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 90
|
||||
|
||||
@classmethod
|
||||
def can_implement(
|
||||
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
if not current_platform.is_rocm():
|
||||
return (
|
||||
False,
|
||||
"AiterScaledMMLinearKernel requires `aiter` which is not " +
|
||||
"currently supported on non-ROCm platform.")
|
||||
|
||||
try:
|
||||
import aiter # noqa: F401 # deliberately attempt to import aiter
|
||||
except Exception:
|
||||
return (
|
||||
False,
|
||||
"AiterScaledMMLinearKernel requires `aiter` which is not " +
|
||||
"installed on ROCm.")
|
||||
# Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
|
||||
if not (
|
||||
envs.VLLM_ROCM_USE_AITER_LINEAR \
|
||||
and envs.VLLM_ROCM_USE_AITER
|
||||
):
|
||||
return (False, "AiterScaledMMLinearKernel is disabled. " +
|
||||
"Enable by setting `VLLM_ROCM_USE_AITER=1` " +
|
||||
"and `VLLM_ROCM_USE_AITER_LINEAR=1`. " +
|
||||
"`VLLM_ROCM_USE_AITER_LINEAR` default is True.")
|
||||
|
||||
if not c.input_symmetric:
|
||||
return (False,
|
||||
"AiterScaledMMLinearKernel only supports symmetric " +
|
||||
"quantization.")
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
super().process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""
|
||||
`AiterScaledMMLinearKernel` implements a fused version of
|
||||
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
|
||||
where scale_a * a and scale_b * b are implemented using numpy-style
|
||||
broadcasting.
|
||||
Currently only support per-tensor-per-tensor GEMM
|
||||
and per-token-per-channel GEMM through AITER
|
||||
w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support
|
||||
ATIER block scaled GEMM and mix-precision GEMM.
|
||||
"""
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
|
||||
|
||||
# ops.scaled_int8_quant supports both dynamic and static quant:
|
||||
# * dynamic, i_s is None and x_s computed from x.
|
||||
# * static, i_s is scalar and x_s is i_s.
|
||||
symmetric = azp_adj is None
|
||||
assert symmetric, ("AiterScaledMMLinearKernel only supports"
|
||||
" symmetric quantization.")
|
||||
x_q, x_s, x_zp = ops.scaled_int8_quant(x,
|
||||
i_s,
|
||||
i_zp,
|
||||
symmetric=symmetric)
|
||||
|
||||
assert x_zp is None, ("AiterScaledMMLinearKernel only supports"
|
||||
" symmetric quantization.")
|
||||
out_dtype = x.dtype
|
||||
|
||||
assert (w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0)
|
||||
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
|
||||
assert bias is None or bias.shape[0] == w_q.shape[
|
||||
1] and bias.dtype == out_dtype
|
||||
|
||||
m = x_q.shape[0] # a
|
||||
n = w_q.shape[1] # b
|
||||
|
||||
per_tensor_scale_a = (x_s.numel() == 1)
|
||||
per_tensor_scale_b = (w_s.numel() == 1)
|
||||
per_token_scale_a = (x_s.numel() == m)
|
||||
per_channel_scale_b = (w_s.numel() == n)
|
||||
|
||||
# @TODO:
|
||||
# Maybe broadcast the per-tensor-scale into per-channel-scale
|
||||
# if one of the scale is a per-channel-scale.
|
||||
# For now, it only supports:
|
||||
# - per-tensor-per-tensor a8w8 scaled GEMM, and
|
||||
# - per-token-per-channel a8w8 scaled GEMM
|
||||
assert ((per_tensor_scale_a and per_tensor_scale_b)
|
||||
or (per_token_scale_a and per_channel_scale_b)), (
|
||||
"Currently only support per-tensor-per-tensor GEMM " +
|
||||
" and per-token-per-channel GEMM through AITER"
|
||||
" w8a8 scaled gemm. `AiterScaledMMLinearKernel` " +
|
||||
"does not support AITER block scaled GEMM.")
|
||||
|
||||
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
|
||||
# a to be [M, K]
|
||||
# b to be [N, K]
|
||||
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
|
||||
return torch.ops.vllm.rocm_aiter_gemm_w8a8(x_q, w_q.t(), x_s, w_s,
|
||||
bias, out_dtype)
|
||||
206
vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py
Normal file
206
vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise)
|
||||
from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import CpuArchEnum
|
||||
|
||||
from .ScaledMMLinearKernel import (ScaledMMLinearKernel,
|
||||
ScaledMMLinearLayerConfig)
|
||||
|
||||
|
||||
class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def can_implement(
|
||||
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
if not current_platform.is_cpu():
|
||||
return False, "CPUScaledMM requires running on CPU."
|
||||
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
dtype = weight.dtype
|
||||
N, K = weight.size()
|
||||
if (current_platform.get_cpu_architecture() == CpuArchEnum.X86
|
||||
and envs.VLLM_CPU_SGL_KERNEL and self.config.input_symmetric
|
||||
and check_cpu_sgl_kernel(N, K, dtype)):
|
||||
self.linear_method = self._apply_weights_sgl
|
||||
self.process_weights_for_sgl(layer)
|
||||
else:
|
||||
self.linear_method = self._apply_weights_onednn
|
||||
self.process_weights_for_onednn(layer)
|
||||
|
||||
def process_weights_for_onednn(self, layer: torch.nn.Module) -> None:
|
||||
# WEIGHT
|
||||
# Transpose to [K, N] for convenience
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
replace_parameter(
|
||||
layer, self.w_q_name,
|
||||
torch.nn.Parameter(weight.t().data, requires_grad=False))
|
||||
|
||||
# WEIGHT SCALE
|
||||
# oneDNN kernels support only per-tensor and per-channel.
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), convert to the per-channel case.
|
||||
is_fused_module = len(layer.logical_widths) > 1
|
||||
weight_scale = getattr(layer, self.w_s_name)
|
||||
if is_fused_module and not self.config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale,
|
||||
layer.logical_widths)
|
||||
replace_parameter(
|
||||
layer, self.w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False))
|
||||
|
||||
# INPUT SCALE
|
||||
if self.config.is_static_input_scheme:
|
||||
input_scale = getattr(layer, self.i_s_name)
|
||||
|
||||
if self.config.input_symmetric:
|
||||
replace_parameter(
|
||||
layer, self.i_s_name,
|
||||
torch.nn.Parameter(input_scale.max(), requires_grad=False))
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
else:
|
||||
input_zero_point = getattr(layer, self.i_zp_name)
|
||||
|
||||
# reconstruct the ranges
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
azps = input_zero_point.to(dtype=torch.int32)
|
||||
range_max = (input_scale * (int8_traits.max - azps)).max()
|
||||
range_min = (input_scale * (int8_traits.min - azps)).min()
|
||||
|
||||
scale = (range_max - range_min) / (int8_traits.max -
|
||||
int8_traits.min)
|
||||
replace_parameter(
|
||||
layer, self.i_s_name,
|
||||
torch.nn.Parameter(scale, requires_grad=False))
|
||||
|
||||
azp = (int8_traits.min -
|
||||
range_min / scale).round().to(dtype=torch.int32)
|
||||
replace_parameter(layer, self.i_zp_name,
|
||||
torch.nn.Parameter(azp, requires_grad=False))
|
||||
|
||||
else:
|
||||
setattr(layer, self.i_s_name, None)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
|
||||
# Different from cutlass, oneDNN kernels only need the AZP adjustment
|
||||
# term for dynamic quantization. And s_b should be folded into the
|
||||
# term. Such as:
|
||||
# s_a * s_b * [(A - zp_a)B] + bias =
|
||||
# s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias =
|
||||
# s_a * GEMM_output - s_a * zp_a * adj + bias
|
||||
if not (self.config.input_symmetric
|
||||
and self.config.is_static_input_scheme):
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
weight_scale = getattr(layer, self.w_s_name)
|
||||
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32)
|
||||
azp_adj = azp_adj * weight_scale.squeeze()
|
||||
setattr(layer, self.azp_adj_name,
|
||||
torch.nn.Parameter(azp_adj, requires_grad=False))
|
||||
else:
|
||||
setattr(layer, self.azp_adj_name, None)
|
||||
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
self.dnnl_handler = ops.create_onednn_scaled_mm(
|
||||
weight,
|
||||
getattr(layer, self.w_s_name),
|
||||
torch.get_default_dtype(),
|
||||
getattr(layer, self.i_s_name) is None,
|
||||
not self.config.input_symmetric,
|
||||
32,
|
||||
)
|
||||
# weight is prepacked and maintained by the dnnl_handler,
|
||||
# release the original weight
|
||||
setattr(layer, self.w_q_name, None)
|
||||
del weight
|
||||
|
||||
def process_weights_for_sgl(self, layer: torch.nn.Module) -> None:
|
||||
# WEIGHT
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
packed_weight = torch.ops._C.convert_weight_packed(weight)
|
||||
replace_parameter(
|
||||
layer, self.w_q_name,
|
||||
torch.nn.Parameter(packed_weight, requires_grad=False))
|
||||
|
||||
if layer.bias is not None:
|
||||
bias = layer.bias
|
||||
layer.register_parameter(
|
||||
"bias_fp32",
|
||||
torch.nn.Parameter(bias.float().data, requires_grad=False))
|
||||
|
||||
# WEIGHT SCALE
|
||||
# CPU SGL kernels only support per-channel.
|
||||
# For per-tensor quant, convert to the per-channel case.
|
||||
weight_scale = getattr(layer, self.w_s_name)
|
||||
if not self.config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale,
|
||||
layer.logical_widths)
|
||||
replace_parameter(
|
||||
layer, self.w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False))
|
||||
|
||||
setattr(layer, self.i_s_name, None)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
setattr(layer, self.azp_adj_name, None)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return self.linear_method(
|
||||
layer,
|
||||
x,
|
||||
bias,
|
||||
)
|
||||
|
||||
def _apply_weights_onednn(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
|
||||
|
||||
# ops.scaled_int8_quant supports both dynamic and static quant:
|
||||
# * dynamic, i_s is None and x_s computed from x.
|
||||
# * static, i_s is scalar and x_s is i_s.
|
||||
x_q, x_s, x_zp = ops.onednn_scaled_int8_quant(
|
||||
x, i_s, i_zp, self.config.input_symmetric)
|
||||
|
||||
m = x.size(0)
|
||||
n = self.dnnl_handler.n
|
||||
out = torch.empty((m, n), dtype=x.dtype)
|
||||
ops.onednn_scaled_mm(self.dnnl_handler, x_q, out, x_s, x_zp, azp_adj,
|
||||
bias)
|
||||
|
||||
return out
|
||||
|
||||
def _apply_weights_sgl(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
w_q, w_s, _, _, _ = self._get_weight_params(layer)
|
||||
return torch.ops._C.int8_scaled_mm_with_quant(
|
||||
x,
|
||||
w_q,
|
||||
w_s,
|
||||
layer.bias_fp32 if bias is not None else None,
|
||||
x.dtype,
|
||||
True,
|
||||
)
|
||||
@@ -0,0 +1,137 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .ScaledMMLinearKernel import (ScaledMMLinearKernel,
|
||||
ScaledMMLinearLayerConfig)
|
||||
|
||||
|
||||
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def can_implement(
|
||||
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
return False, "CutlassScaledMM requires running on CUDA."
|
||||
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# WEIGHT
|
||||
# Cutlass kernels need transposed weight.
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
replace_parameter(
|
||||
layer, self.w_q_name,
|
||||
torch.nn.Parameter(weight.t().data, requires_grad=False))
|
||||
|
||||
# WEIGHT SCALE
|
||||
# Cutlass kernels support only per-tensor and per-channel.
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), convert to the per-channel case.
|
||||
is_fused_module = len(layer.logical_widths) > 1
|
||||
weight_scale = getattr(layer, self.w_s_name)
|
||||
if is_fused_module and not self.config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale,
|
||||
layer.logical_widths)
|
||||
replace_parameter(
|
||||
layer, self.w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False))
|
||||
|
||||
# INPUT SCALE
|
||||
if self.config.is_static_input_scheme:
|
||||
input_scale = getattr(layer, self.i_s_name)
|
||||
|
||||
if self.config.input_symmetric:
|
||||
replace_parameter(
|
||||
layer, self.i_s_name,
|
||||
torch.nn.Parameter(input_scale.max(), requires_grad=False))
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
else:
|
||||
input_zero_point = getattr(layer, self.i_zp_name)
|
||||
|
||||
# reconstruct the ranges
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
azps = input_zero_point.to(dtype=torch.int32)
|
||||
range_max = (input_scale * (int8_traits.max - azps)).max()
|
||||
range_min = (input_scale * (int8_traits.min - azps)).min()
|
||||
|
||||
scale = (range_max - range_min) / (int8_traits.max -
|
||||
int8_traits.min)
|
||||
replace_parameter(
|
||||
layer, self.i_s_name,
|
||||
torch.nn.Parameter(scale, requires_grad=False))
|
||||
|
||||
# AZP loaded as int8 but used as int32
|
||||
azp = (int8_traits.min -
|
||||
range_min / scale).to(dtype=torch.int32)
|
||||
replace_parameter(layer, self.i_zp_name,
|
||||
torch.nn.Parameter(azp, requires_grad=False))
|
||||
|
||||
else:
|
||||
setattr(layer, self.i_s_name, None)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
|
||||
# azp_adj is the AZP adjustment term, used to account for weights.
|
||||
# It does not depend on scales or azp, so it is the same for
|
||||
# static and dynamic quantization.
|
||||
# For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
|
||||
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
|
||||
if not self.config.input_symmetric:
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
|
||||
if self.config.is_static_input_scheme:
|
||||
# cutlass_w8a8 requires azp to be folded into azp_adj
|
||||
# in the per-tensor case
|
||||
azp_adj = getattr(layer, self.i_zp_name) * azp_adj
|
||||
setattr(layer, self.azp_adj_name,
|
||||
torch.nn.Parameter(azp_adj, requires_grad=False))
|
||||
else:
|
||||
setattr(layer, self.azp_adj_name, None)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
|
||||
|
||||
# ops.scaled_int8_quant supports both dynamic and static quant:
|
||||
# * dynamic, i_s is None and x_s computed from x.
|
||||
# * static, i_s is scalar and x_s is i_s.
|
||||
symmetric = azp_adj is None
|
||||
x_q, x_s, x_zp = ops.scaled_int8_quant(x.contiguous(),
|
||||
i_s,
|
||||
i_zp,
|
||||
symmetric=symmetric)
|
||||
|
||||
if x_zp is not None:
|
||||
# Currently, static is always per-tensor and dynamic is per-token
|
||||
static = i_zp is not None
|
||||
azp = None if static else x_zp
|
||||
return ops.cutlass_scaled_mm_azp(x_q,
|
||||
w_q,
|
||||
scale_a=x_s,
|
||||
scale_b=w_s,
|
||||
out_dtype=x.dtype,
|
||||
azp_adj=azp_adj,
|
||||
azp=azp,
|
||||
bias=bias)
|
||||
return ops.cutlass_scaled_mm(x_q,
|
||||
w_q,
|
||||
scale_a=x_s,
|
||||
scale_b=w_s,
|
||||
out_dtype=x.dtype,
|
||||
bias=bias)
|
||||
@@ -0,0 +1,41 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .cutlass import CutlassScaledMMLinearKernel
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
|
||||
|
||||
|
||||
class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def can_implement(
|
||||
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
if current_platform.is_cpu():
|
||||
return (
|
||||
False,
|
||||
"TritonScaledMMLinearKernel requires Triton which is not " +
|
||||
"currently supported on CPU.")
|
||||
if not c.input_symmetric:
|
||||
return (False,
|
||||
"TritonScaledMMLinearKernel only supports symmetric " +
|
||||
"quantization.")
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
super().process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return super().apply_weights(layer, x, bias)
|
||||
104
vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
Normal file
104
vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from functorch.experimental.control_flow import cond # noqa: F401
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .ScaledMMLinearKernel import (ScaledMMLinearKernel,
|
||||
ScaledMMLinearLayerConfig)
|
||||
|
||||
|
||||
class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
raise NotImplementedError(
|
||||
"TPU platform does have a concept of compute capability, "
|
||||
"this method should not be called.")
|
||||
|
||||
@classmethod
|
||||
def can_implement(
|
||||
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
|
||||
if not current_platform.is_tpu():
|
||||
return False, "ScaledMMXLA requires running on TPU."
|
||||
|
||||
if c.is_static_input_scheme:
|
||||
return False, "ScaledMMXLA requires dynamic activation scales."
|
||||
|
||||
if not c.input_symmetric:
|
||||
return False, "ScaledMMXLA requires symmetric activation scales."
|
||||
|
||||
if not c.is_channelwise:
|
||||
return False, "ScaledMMXLA requires channelwise weight scales"
|
||||
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# WEIGHT
|
||||
# [out, in] (different than cutlass_scaled_mm)
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
replace_parameter(layer, self.w_q_name,
|
||||
torch.nn.Parameter(weight.data, requires_grad=False))
|
||||
|
||||
# WEIGHT SCALE
|
||||
# XLA kernels support only per-tensor and per-channel.
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), convert to the per-channel case.
|
||||
is_fused_module = len(layer.logical_widths) > 1
|
||||
weight_scale = getattr(layer, self.w_s_name)
|
||||
if is_fused_module and not self.config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale,
|
||||
layer.logical_widths)
|
||||
|
||||
# [out_channel,] (different than cutlass_scaled_mm)
|
||||
weight_scale = weight_scale.squeeze(-1)
|
||||
replace_parameter(
|
||||
layer, self.w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False))
|
||||
|
||||
# Only support symmetric dynamic activation quantization.
|
||||
setattr(layer, self.i_s_name, None)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
setattr(layer, self.azp_adj_name, None)
|
||||
|
||||
# Filter warning for cond usage in apply_weights. It is okay
|
||||
# to specialize the graph since bias is not dynamic.
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=
|
||||
"Pred is a Python constant. When used with torch.cond, it specializes on one of the branches." # noqa: E501
|
||||
)
|
||||
|
||||
def no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]):
|
||||
return x
|
||||
|
||||
def add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]):
|
||||
return x + bias
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
w_q, w_s, _, _, _ = self._get_weight_params(layer)
|
||||
|
||||
# Required to register custom ops.
|
||||
import torch_xla.experimental.custom_kernel # noqa: F401
|
||||
out = torch.ops.xla.quantized_matmul_int8(
|
||||
x,
|
||||
w_q,
|
||||
w_s,
|
||||
quantize_activation=True,
|
||||
)
|
||||
|
||||
# Explicitly capture control flow to make dynamo happy.
|
||||
# https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
|
||||
return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])
|
||||
143
vllm/model_executor/layers/quantization/kv_cache.py
Normal file
143
vllm/model_executor/layers/quantization/kv_cache.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BaseKVCacheMethod(QuantizeMethodBase):
|
||||
"""
|
||||
Quant method that adds `_k_scale` and `_v_scale` attributes to the
|
||||
Attention layer to support loading those scaling factors from checkpoints.
|
||||
The k/v_scale will be used to:
|
||||
- quantize k/v_cache entries before saving them to the cache
|
||||
- dequantize k/v_cache entries before fetching them from the cache
|
||||
|
||||
:param quant_config: the appropriate QuantizationConfig
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: QuantizationConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module):
|
||||
"""
|
||||
Create "weight" (aka q_scale, k_scale and v_scale)
|
||||
for an attention layer.
|
||||
"""
|
||||
# Initialize the Q and KV cache scales to -1.0, an invalid value.
|
||||
# If the q and k/v_scales appear in the checkpoint, it will be
|
||||
# overwritten when loading weights.
|
||||
layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0),
|
||||
requires_grad=False)
|
||||
layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0),
|
||||
requires_grad=False)
|
||||
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0),
|
||||
requires_grad=False)
|
||||
# Initialize P = softmax(QK^T) scales
|
||||
layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0),
|
||||
requires_grad=False)
|
||||
|
||||
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
|
||||
raise RuntimeError(
|
||||
f"{self.__class__.__name__}.apply should not be called.")
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
|
||||
# regardless whether the kv-scale is available in the checkpoint.
|
||||
# No need to process kv scales after loading if we are going to
|
||||
# calculate them on the fly.
|
||||
if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
|
||||
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
|
||||
# We prefer to use separate k_scale and v_scale if present
|
||||
k_scale = layer.k_scale.to("cpu").tolist()
|
||||
v_scale = layer.v_scale.to("cpu").tolist()
|
||||
if current_platform.is_fp8_fnuz():
|
||||
k_scale *= 2
|
||||
v_scale *= 2
|
||||
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
|
||||
# If no scales were loaded (both scales are invalid negative
|
||||
# values), use the default value of 1.0
|
||||
k_scale = 1.0
|
||||
v_scale = 1.0
|
||||
else:
|
||||
# If we find a single kv_scale in the checkpoint, we remap
|
||||
# kv_scale to k_scale during weight loading, and duplicate
|
||||
# k_scale to v_scale here
|
||||
assert layer.k_scale > 0.0
|
||||
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
|
||||
k_scale = scale_to_duplicate.to("cpu").tolist()
|
||||
v_scale = scale_to_duplicate.to("cpu").tolist()
|
||||
if current_platform.is_fp8_fnuz():
|
||||
k_scale *= 2
|
||||
v_scale *= 2
|
||||
|
||||
if not isinstance(k_scale, float) or not isinstance(
|
||||
v_scale, float):
|
||||
raise ValueError("Only support per-tensor scaling factor "
|
||||
"for fp8 KV cache")
|
||||
|
||||
if layer.q_scale < 0.0:
|
||||
logger.warning_once(
|
||||
"Checkpoint does not provide a q scaling factor. "
|
||||
"Setting it to k_scale. This only matters for "
|
||||
"the flash-attn backend.")
|
||||
layer._q_scale.copy_(k_scale)
|
||||
layer._q_scale_float = k_scale
|
||||
|
||||
# These are used in the final Attention.forward()
|
||||
layer._k_scale.copy_(k_scale)
|
||||
layer._v_scale.copy_(v_scale)
|
||||
layer._k_scale_float = k_scale
|
||||
layer._v_scale_float = v_scale
|
||||
if (k_scale == 1.0 and v_scale == 1.0
|
||||
and "e5m2" not in layer.kv_cache_dtype):
|
||||
logger.warning_once(
|
||||
"Using KV cache scaling factor 1.0 for fp8_e4m3. This "
|
||||
"may cause accuracy issues. Please make sure k/v_scale "
|
||||
"scaling factors are available in the fp8 checkpoint.")
|
||||
|
||||
if layer.q_scale > 0.0:
|
||||
q_scale = layer.q_scale
|
||||
if current_platform.is_fp8_fnuz():
|
||||
q_scale *= 2
|
||||
layer.calculate_kv_scales = False
|
||||
else:
|
||||
q_scale = 1.0
|
||||
if layer.prob_scale > 0.0:
|
||||
prob_scale = layer.prob_scale
|
||||
if current_platform.is_fp8_fnuz():
|
||||
prob_scale *= 2
|
||||
else:
|
||||
prob_scale = 1.0
|
||||
|
||||
is_singleton_float = lambda x: isinstance(x, float) or isinstance(
|
||||
x, torch.Tensor) and x.numel() == 1 and x.is_floating_point()
|
||||
if not is_singleton_float(q_scale) or not is_singleton_float(
|
||||
prob_scale):
|
||||
raise ValueError("Only support per-tensor scaling factor"
|
||||
"for fp8-quantized Q/prob")
|
||||
|
||||
# These are used in the final Attention.forward()
|
||||
layer._q_scale.copy_(q_scale)
|
||||
layer._q_scale_float = q_scale.item() if isinstance(
|
||||
q_scale, torch.Tensor) else q_scale
|
||||
|
||||
layer._prob_scale.copy_(prob_scale)
|
||||
if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0
|
||||
or prob_scale == 1.0):
|
||||
logger.warning_once(
|
||||
f"Using uncalibrated q_scale {q_scale} and/or prob_scale "
|
||||
f"{prob_scale} with fp8 attention. This may cause accuracy "
|
||||
"issues. Please make sure q/prob scaling factors are "
|
||||
"available in the fp8 checkpoint.")
|
||||
|
||||
del layer.k_scale
|
||||
del layer.v_scale
|
||||
del layer.q_scale
|
||||
del layer.prob_scale
|
||||
1596
vllm/model_executor/layers/quantization/modelopt.py
Normal file
1596
vllm/model_executor/layers/quantization/modelopt.py
Normal file
File diff suppressed because it is too large
Load Diff
484
vllm/model_executor/layers/quantization/moe_wna16.py
Normal file
484
vllm/model_executor/layers/quantization/moe_wna16.py
Normal file
@@ -0,0 +1,484 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, int4_w4a16_moe_quant_config,
|
||||
int8_w8a16_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supports_layer)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class MoeWNA16Config(QuantizationConfig):
|
||||
"""Config class for MOE WNA16 (W8A16/W4A16) quantization."""
|
||||
|
||||
def __init__(self, linear_quant_method: str, weight_bits: int,
|
||||
group_size: int, has_zp: bool, lm_head_quantized: bool,
|
||||
modules_to_not_convert: Optional[list[str]],
|
||||
full_config: dict[str, Any]) -> None:
|
||||
super().__init__()
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.has_zp = has_zp
|
||||
self.bit8_pack_factor = 8 // self.weight_bits
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.linear_quant_method = linear_quant_method
|
||||
self.full_config = full_config
|
||||
self.use_marlin = False
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import (
|
||||
AWQMarlinConfig)
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinConfig)
|
||||
if self.linear_quant_method == "gptq":
|
||||
self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(
|
||||
full_config)
|
||||
elif self.linear_quant_method == "awq":
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
device_capability = (-1 if capability_tuple is None else
|
||||
capability_tuple.to_int())
|
||||
awq_min_capability = AWQConfig.get_min_capability()
|
||||
if device_capability < awq_min_capability:
|
||||
raise ValueError(
|
||||
"The quantization method moe_wna16 + awq is not supported "
|
||||
"for the current GPU. "
|
||||
f"Minimum capability: {awq_min_capability}. "
|
||||
f"Current capability: {device_capability}.")
|
||||
self.use_marlin = AWQMarlinConfig.is_awq_marlin_compatible(
|
||||
full_config)
|
||||
else:
|
||||
raise ValueError("moe_wna16 only support gptq and awq.")
|
||||
|
||||
if modules_to_not_convert is None:
|
||||
self.modules_to_not_convert = []
|
||||
else:
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "moe_wna16"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "MoeWNA16Config":
|
||||
linear_quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
if linear_quant_method == "gptq":
|
||||
has_zp = not cls.get_from_keys(config, ["sym"])
|
||||
modules_to_not_convert = []
|
||||
elif linear_quant_method == "awq":
|
||||
has_zp = cls.get_from_keys(config, ["zero_point"])
|
||||
modules_to_not_convert = cls.get_from_keys_or(
|
||||
config, ["modules_to_not_convert"], None)
|
||||
else:
|
||||
raise ValueError("moe_wna16 only support gptq and awq.")
|
||||
|
||||
return cls(linear_quant_method, weight_bits, group_size, has_zp,
|
||||
lm_head_quantized, modules_to_not_convert, config)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
|
||||
if can_convert and user_quant == "moe_wna16":
|
||||
return cls.get_name()
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_moe_wna16_compatible(cls, quant_config: dict[str, Any]):
|
||||
# Extract data from quant config.
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits")
|
||||
desc_act = quant_config.get("desc_act")
|
||||
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
device_capability = (-1 if capability_tuple is None else
|
||||
capability_tuple.to_int())
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
awq_min_capability = AWQConfig.get_min_capability()
|
||||
|
||||
gptq_compatible = quant_method == "gptq" and \
|
||||
not desc_act and num_bits in [4, 8]
|
||||
awq_compatible = quant_method == "awq" and num_bits == 4 and \
|
||||
device_capability >= awq_min_capability
|
||||
|
||||
return gptq_compatible or awq_compatible
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
|
||||
return UnquantizedLinearMethod()
|
||||
elif isinstance(layer, LinearBase):
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import (
|
||||
AWQMarlinConfig)
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinConfig)
|
||||
if self.linear_quant_method == "gptq":
|
||||
if self.use_marlin:
|
||||
return GPTQMarlinConfig.from_config(
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
else:
|
||||
return GPTQConfig.from_config(
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
elif self.linear_quant_method == "awq":
|
||||
if self.use_marlin and check_marlin_supports_layer(
|
||||
layer, self.group_size):
|
||||
return AWQMarlinConfig.from_config(
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
else:
|
||||
return AWQConfig.from_config(
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
else:
|
||||
raise ValueError("moe_wna16 only support gptq and awq.")
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return MoeWNA16Method(self, layer.moe_config)
|
||||
return None
|
||||
|
||||
|
||||
def is_layer_skipped_quant(prefix: str, modules_to_not_convert: list[str]):
|
||||
return any(module_name in prefix for module_name in modules_to_not_convert)
|
||||
|
||||
|
||||
class MoeWNA16Method(FusedMoEMethodBase):
|
||||
"""Linear method for MOE WNA16 (W8A16/W4A16) quantization.
|
||||
|
||||
Args:
|
||||
quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: MoeWNA16Config,
|
||||
moe: "FusedMoEConfig") -> None:
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
|
||||
self.moe = layer
|
||||
layer.quant_config = self.quant_config
|
||||
bit8_pack_factor = self.quant_config.bit8_pack_factor
|
||||
group_size = self.quant_config.group_size
|
||||
group_size_div_factor = 1
|
||||
|
||||
# make intermediate_size and hidden_size divisible by group_size
|
||||
# we reduce the group size to ensure that
|
||||
# and we would repeat the loaded_weight later
|
||||
while intermediate_size_per_partition % group_size or \
|
||||
hidden_size % group_size:
|
||||
group_size = group_size // 2
|
||||
group_size_div_factor *= 2
|
||||
assert group_size >= 32
|
||||
layer.group_size = group_size
|
||||
layer.group_size_div_factor = group_size_div_factor
|
||||
|
||||
strategy = FusedMoeWeightScaleSupported.GROUP.value
|
||||
extra_weight_attrs.update({
|
||||
"quant_method": strategy,
|
||||
"is_transposed": False
|
||||
})
|
||||
|
||||
assert 'weight_loader' in extra_weight_attrs
|
||||
weight_loader = extra_weight_attrs['weight_loader']
|
||||
wrapped_weight_loader = MoeWNA16Method.get_weight_loader(
|
||||
layer, weight_loader)
|
||||
extra_weight_attrs['weight_loader'] = wrapped_weight_loader
|
||||
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_qweight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // bit8_pack_factor,
|
||||
dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_qweight", w13_qweight)
|
||||
set_weight_attrs(w13_qweight, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_qweight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition // bit8_pack_factor,
|
||||
dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_qweight", w2_qweight)
|
||||
set_weight_attrs(w2_qweight, extra_weight_attrs)
|
||||
|
||||
w13_scales = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // group_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_scales", w13_scales)
|
||||
set_weight_attrs(w13_scales, extra_weight_attrs)
|
||||
|
||||
w2_scales = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition // group_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_scales", w2_scales)
|
||||
set_weight_attrs(w2_scales, extra_weight_attrs)
|
||||
|
||||
if self.quant_config.has_zp:
|
||||
w13_qzeros = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition // bit8_pack_factor,
|
||||
hidden_size // group_size,
|
||||
dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_qzeros", w13_qzeros)
|
||||
set_weight_attrs(w13_qzeros, extra_weight_attrs)
|
||||
|
||||
w2_qzeros = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
hidden_size // bit8_pack_factor,
|
||||
intermediate_size_per_partition // group_size,
|
||||
dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_qzeros", w2_qzeros)
|
||||
set_weight_attrs(w2_qzeros, extra_weight_attrs)
|
||||
|
||||
if self.quant_config.linear_quant_method == "gptq":
|
||||
# some param are unused, but we need to init them in order to
|
||||
# load weights
|
||||
invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
|
||||
if not self.quant_config.has_zp:
|
||||
invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
|
||||
for key in invalid_param_keys:
|
||||
param = torch.nn.Parameter(torch.empty((0, ),
|
||||
dtype=torch.int32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter(key, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
weight_bits = self.quant_config.weight_bits
|
||||
has_zp = self.quant_config.has_zp
|
||||
assert weight_bits == 4 or weight_bits == 8
|
||||
config_builder = (int4_w4a16_moe_quant_config
|
||||
if weight_bits == 4 else int8_w8a16_moe_quant_config)
|
||||
|
||||
return config_builder(
|
||||
w1_scale=layer.w13_scales,
|
||||
w2_scale=layer.w2_scales,
|
||||
w1_zp=layer.w13_qzeros if has_zp else None,
|
||||
w2_zp=layer.w2_qzeros if has_zp else None,
|
||||
block_shape=[0, layer.group_size],
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `MoeWNA16Method` yet.")
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_weight_loader(layer, weight_loader):
|
||||
|
||||
def convert_awq_tensor(tensor, tensor_type):
|
||||
# convert awq qweight/qzeros to a standard format (assume int4)
|
||||
# qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
|
||||
# qzeros: (k // group_size, n // pack_factor_bit32) ->
|
||||
# (n // pack_factor_bit8, k // group_size)
|
||||
# pack_factor_bit32 = 32 // weight_bits
|
||||
# pack_factor_bit8 = 8 // weight_bits
|
||||
|
||||
# 0. suppose origin shape (a, b), dtype int32
|
||||
# 1. convert to uint8, shape (a, b) -> (a, 4 * b)
|
||||
size0 = tensor.size(0)
|
||||
tensor = tensor.view(torch.uint8)
|
||||
|
||||
# 2. unpack to uint4 (only when weight_bits == 4)
|
||||
# shape (a, 4 * b) -> (a, 4 * b, 2)
|
||||
shifter = torch.tensor([0, 4],
|
||||
dtype=torch.uint8,
|
||||
device=tensor.device)
|
||||
tensor = (tensor[:, :, None] >> shifter) & 0xF
|
||||
|
||||
# 3. change order, see
|
||||
# https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
|
||||
# shape -> (a, 4 * b * pack_factor_bit8)
|
||||
reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
|
||||
tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order]
|
||||
tensor = tensor.view(size0, -1)
|
||||
|
||||
# 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
|
||||
tensor = tensor.T.contiguous()
|
||||
|
||||
# 5. repack (only when weight_bits == 4)
|
||||
# qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
|
||||
# qzeros shape -> (4 * b, a)
|
||||
|
||||
if tensor_type == "qweight":
|
||||
tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
|
||||
elif tensor_type == "qzeros":
|
||||
tensor = tensor[1::2, :] * 16 + tensor[::2, :]
|
||||
return tensor
|
||||
|
||||
def convert_gptq_int4_qzeros(tensor):
|
||||
tensor = tensor.view(torch.uint8)
|
||||
shifter = torch.tensor([0, 4],
|
||||
dtype=torch.uint8,
|
||||
device=tensor.device)
|
||||
tensor = (tensor[:, :, None] >> shifter) & 0xF
|
||||
tensor = tensor + 1
|
||||
tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
|
||||
return tensor
|
||||
|
||||
def moe_wna16_weight_loader(param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
weight_name: str,
|
||||
shard_id: str,
|
||||
expert_id: int,
|
||||
return_success: bool = False):
|
||||
if "g_idx" in weight_name:
|
||||
return False if return_success else None
|
||||
if not layer.quant_config.has_zp and "qzeros" in weight_name:
|
||||
return False if return_success else None
|
||||
|
||||
device = get_tp_group().device
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
loaded_weight = loaded_weight.to(device)
|
||||
shard_size = layer.intermediate_size_per_partition
|
||||
|
||||
# convert gptq and awq weight to a standard format
|
||||
if layer.quant_config.linear_quant_method == "awq":
|
||||
assert layer.quant_config.weight_bits == 4
|
||||
if "weight" in weight_name:
|
||||
loaded_weight = convert_awq_tensor(loaded_weight,
|
||||
"qweight")
|
||||
elif "zeros" in weight_name:
|
||||
loaded_weight = convert_awq_tensor(loaded_weight, "qzeros")
|
||||
else:
|
||||
loaded_weight = loaded_weight.T
|
||||
elif layer.quant_config.linear_quant_method == "gptq":
|
||||
assert layer.quant_config.weight_bits in [4, 8]
|
||||
if "weight" in weight_name:
|
||||
loaded_weight = loaded_weight.T.contiguous().view(
|
||||
torch.uint8)
|
||||
elif "zeros" in weight_name:
|
||||
# add 1 to gptq qzeros to align with awq
|
||||
loaded_weight = loaded_weight.view(torch.uint8)
|
||||
if layer.quant_config.weight_bits == 4:
|
||||
loaded_weight = convert_gptq_int4_qzeros(
|
||||
loaded_weight).T
|
||||
else:
|
||||
loaded_weight = loaded_weight.T + 1
|
||||
else:
|
||||
loaded_weight = loaded_weight.T
|
||||
|
||||
# repeat the qzeros/scales to fit new group size
|
||||
if layer.group_size_div_factor > 1 and \
|
||||
"qzeros" in weight_name or "scales" in weight_name:
|
||||
loaded_weight = loaded_weight.repeat_interleave(
|
||||
layer.group_size_div_factor, 1)
|
||||
|
||||
if "w13_qzeros" in weight_name:
|
||||
tensor = loaded_weight.view(layer.tp_size, -1,
|
||||
loaded_weight.size(1))[tp_rank]
|
||||
if shard_id == "w1":
|
||||
param.data[expert_id, :shard_size // 2] = tensor
|
||||
else:
|
||||
param.data[expert_id, shard_size // 2:] = tensor
|
||||
return True if return_success else None
|
||||
elif "w2_qzeros" in weight_name:
|
||||
param.data[expert_id] = loaded_weight.view(
|
||||
loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank]
|
||||
return True if return_success else None
|
||||
else:
|
||||
# Delegate to the original loader, passing return_success
|
||||
return weight_loader(param,
|
||||
loaded_weight,
|
||||
weight_name,
|
||||
shard_id,
|
||||
expert_id,
|
||||
return_success=return_success)
|
||||
|
||||
return moe_wna16_weight_loader
|
||||
988
vllm/model_executor/layers/quantization/mxfp4.py
Normal file
988
vllm/model_executor/layers/quantization/mxfp4.py
Normal file
@@ -0,0 +1,988 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from enum import Enum
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config,
|
||||
mxfp4_w4a16_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
OAITritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
prepare_moe_fp4_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
_can_support_mxfp4, _swizzle_mxfp4)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer,
|
||||
next_power_of_2, round_up)
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# enum for mxfp4 backend
|
||||
class Mxfp4Backend(Enum):
|
||||
NONE = 0
|
||||
|
||||
# FlashInfer Backend
|
||||
SM100_FI_MXFP4_MXFP8_TRTLLM = 1
|
||||
SM100_FI_MXFP4_MXFP8_CUTLASS = 2
|
||||
SM100_FI_MXFP4_BF16 = 3
|
||||
SM90_FI_MXFP4_BF16 = 4
|
||||
|
||||
# Marlin Backend
|
||||
MARLIN = 5
|
||||
|
||||
# Triton Backend
|
||||
TRITON = 6
|
||||
|
||||
|
||||
def get_mxfp4_backend():
|
||||
# Backend Selection
|
||||
if current_platform.is_cuda():
|
||||
if (current_platform.is_device_capability(90) and has_flashinfer()
|
||||
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
|
||||
logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90")
|
||||
return Mxfp4Backend.SM90_FI_MXFP4_BF16
|
||||
elif (current_platform.is_device_capability(100) and has_flashinfer()
|
||||
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS):
|
||||
logger.info_once(
|
||||
"Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100")
|
||||
return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
|
||||
elif (current_platform.is_device_capability(100) and has_flashinfer()
|
||||
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8):
|
||||
logger.info_once(
|
||||
"Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100, "
|
||||
"for high concurrency throughput workloads consider setting "
|
||||
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 for better "
|
||||
"performance")
|
||||
return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
|
||||
elif current_platform.is_device_capability(100) and has_flashinfer():
|
||||
logger.info_once(
|
||||
"Using FlashInfer MXFP4 BF16 backend for SM100, "
|
||||
"For faster performance on SM100, consider setting "
|
||||
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may impact "
|
||||
"accuracy.")
|
||||
return Mxfp4Backend.SM100_FI_MXFP4_BF16
|
||||
elif ((current_platform.is_device_capability(100)
|
||||
or current_platform.is_device_capability(90))
|
||||
and not has_flashinfer()):
|
||||
logger.warning_once(
|
||||
"MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer "
|
||||
"is not available. This may result in degraded performance. "
|
||||
"Please `pip install vllm[flashinfer]` for best results.")
|
||||
|
||||
# If FlashInfer is not available, try either Marlin or Triton
|
||||
if current_platform.get_device_capability(
|
||||
)[0] < 9 or not has_triton_kernels() or not is_torch_equal_or_newer(
|
||||
"2.8.0"):
|
||||
logger.info_once("Using Marlin backend")
|
||||
return Mxfp4Backend.MARLIN
|
||||
else:
|
||||
logger.info_once("Using Triton backend")
|
||||
return Mxfp4Backend.TRITON
|
||||
elif current_platform.is_rocm() and has_triton_kernels():
|
||||
logger.info_once("Using Triton backend")
|
||||
return Mxfp4Backend.TRITON
|
||||
|
||||
return Mxfp4Backend.NONE
|
||||
|
||||
|
||||
class Mxfp4Config(QuantizationConfig):
|
||||
|
||||
def __init__(self, ignored_layers: Optional[list[str]] = None):
|
||||
super().__init__()
|
||||
self.ignored_layers = ignored_layers
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
return cls()
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "mxfp4"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if self.ignored_layers and is_layer_skipped(
|
||||
prefix=prefix,
|
||||
ignored_layers=self.ignored_layers,
|
||||
fused_mapping=self.packed_modules_mapping):
|
||||
return UnquantizedLinearMethod()
|
||||
raise NotImplementedError("Mxfp4 linear layer is not implemented")
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return Mxfp4MoEMethod(layer.moe_config)
|
||||
elif isinstance(layer, Attention):
|
||||
raise NotImplementedError(
|
||||
"Mxfp4 attention layer is not implemented")
|
||||
return None
|
||||
|
||||
|
||||
class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(self, moe: FusedMoEConfig):
|
||||
super().__init__(moe)
|
||||
self.topk_indices_dtype = None
|
||||
self.moe = moe
|
||||
self.mxfp4_backend = get_mxfp4_backend()
|
||||
self.max_capture_size = get_current_vllm_config(
|
||||
).compilation_config.max_capture_size
|
||||
|
||||
assert self.mxfp4_backend != Mxfp4Backend.NONE, (
|
||||
"No MXFP4 MoE backend (FlashInfer/Marlin/Triton) available."
|
||||
"Please check your environment and try again.")
|
||||
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
self.num_experts = num_experts
|
||||
weight_dtype = torch.uint8
|
||||
scale_dtype = torch.uint8
|
||||
|
||||
# FIXME (zyongye): ship after torch and safetensors support mxfp4
|
||||
# is_torch_mxfp4_available = (
|
||||
# hasattr(torch, "float4_e2m1fn_x2") and
|
||||
# hasattr(torch, "float8_e8m0fnu"))
|
||||
# if is_torch_mxfp4_available:
|
||||
# weight_dtype = torch.float4_e2m1fn_x2
|
||||
# scale_dtype = torch.float8_e8m0fnu
|
||||
|
||||
mxfp4_block = 32
|
||||
|
||||
intermediate_size_per_partition_after_pad = \
|
||||
intermediate_size_per_partition
|
||||
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
|
||||
# The moe marlin kernel requires that for each linear
|
||||
# n % 256 == 0 and k % 128 == 0.
|
||||
# In gate_up_proj:
|
||||
# n = 2 * intermediate_size_per_partition_after_pad
|
||||
# k = hidden_size
|
||||
# In down_proj
|
||||
# n = hidden_size
|
||||
# k = intermediate_size_per_partition_after_pad
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 128)
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
|
||||
layer.params_dtype = params_dtype
|
||||
layer.num_experts = num_experts
|
||||
layer.hidden_size = hidden_size
|
||||
layer.intermediate_size_per_partition = \
|
||||
intermediate_size_per_partition_after_pad
|
||||
elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
|
||||
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
|
||||
# pad the intermediate size to be a multiple of 2 * mxfp4_block
|
||||
# for to hold non-uniform sharded tensor as well as swizzling
|
||||
# other padding to increase performance
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 256)
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
|
||||
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 128)
|
||||
hidden_size = round_up(hidden_size, 128)
|
||||
elif current_platform.is_rocm():
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 256)
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
else:
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 64)
|
||||
|
||||
self.intermediate_size = intermediate_size_per_partition_after_pad
|
||||
self.hidden_size = hidden_size
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
hidden_size // 2,
|
||||
dtype=weight_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
hidden_size // mxfp4_block,
|
||||
dtype=scale_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
|
||||
w13_bias = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
dtype=torch.bfloat16,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_bias", w13_bias)
|
||||
set_weight_attrs(w13_bias, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition_after_pad // 2,
|
||||
dtype=weight_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition_after_pad // mxfp4_block,
|
||||
dtype=scale_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
w2_bias = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
dtype=torch.bfloat16,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_bias", w2_bias)
|
||||
set_weight_attrs(w2_bias, extra_weight_attrs)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
|
||||
prepare_moe_fp4_layer_for_marlin(layer)
|
||||
elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
|
||||
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
|
||||
from flashinfer.fp4_quantization import (
|
||||
nvfp4_block_scale_interleave)
|
||||
from flashinfer.fused_moe.core import (
|
||||
_maybe_get_cached_w2_permute_indices)
|
||||
layer.gemm1_alpha = Parameter(torch.tensor(
|
||||
[1.702] * self.num_experts, dtype=torch.float32).cuda(),
|
||||
requires_grad=False)
|
||||
layer.gemm1_beta = Parameter(torch.tensor(
|
||||
[1.0] * self.num_experts, dtype=torch.float32).cuda(),
|
||||
requires_grad=False)
|
||||
layer.gemm1_clamp_limit = Parameter(torch.tensor(
|
||||
[7.0] * self.num_experts, dtype=torch.float32).cuda(),
|
||||
requires_grad=False)
|
||||
sf_block_size = 32 # mxfp4 block size
|
||||
|
||||
assert (layer.w13_weight.dim() == 3
|
||||
and layer.w13_weight.shape[0] == self.num_experts
|
||||
and layer.w13_weight.shape[1] == self.intermediate_size * 2
|
||||
and layer.w13_weight.shape[2] == self.hidden_size // 2)
|
||||
assert (layer.w13_weight_scale.dim() == 3
|
||||
and layer.w13_weight_scale.shape[0] == self.num_experts
|
||||
and layer.w13_weight_scale.shape[1]
|
||||
== self.intermediate_size * 2
|
||||
and layer.w13_weight_scale.shape[2]
|
||||
== self.hidden_size // sf_block_size)
|
||||
assert (layer.w2_weight.dim() == 3
|
||||
and layer.w2_weight.shape[0] == self.num_experts
|
||||
and layer.w2_weight.shape[1] == self.hidden_size and
|
||||
layer.w2_weight.shape[2] == self.intermediate_size // 2)
|
||||
assert (layer.w2_weight_scale.dim() == 3
|
||||
and layer.w2_weight_scale.shape[1] == self.hidden_size
|
||||
and layer.w2_weight_scale.shape[2]
|
||||
== self.intermediate_size // sf_block_size)
|
||||
assert (layer.w13_bias.dim() == 2
|
||||
and layer.w13_bias.shape[0] == self.num_experts
|
||||
and layer.w13_bias.shape[1] == self.intermediate_size * 2)
|
||||
assert (layer.w2_bias.dim() == 2
|
||||
and layer.w2_bias.shape[0] == self.num_experts
|
||||
and layer.w2_bias.shape[1] == self.hidden_size)
|
||||
|
||||
w13_weight_scale = layer.w13_weight_scale.data
|
||||
w2_weight_scale = layer.w2_weight_scale.data
|
||||
w13_weight = layer.w13_weight.data
|
||||
w2_weight = layer.w2_weight.data
|
||||
w13_bias = layer.w13_bias.data.to(torch.float32)
|
||||
w2_bias = layer.w2_bias.data.to(torch.float32)
|
||||
|
||||
# Swap w1 and w3 as the definition of
|
||||
# swiglu is different in the trtllm-gen
|
||||
def swap_every_two_rows(x, axis=-1):
|
||||
shape = x.shape
|
||||
if axis < 0:
|
||||
axis = len(shape) + axis
|
||||
|
||||
# Create a new shape with pairs swapped along specified axis
|
||||
new_shape = list(shape)
|
||||
new_shape[axis] = shape[axis] // 2
|
||||
new_shape.insert(axis + 1, 2)
|
||||
|
||||
# Reshape to expose pairs, swap them, and reshape back
|
||||
x = x.reshape(*new_shape)
|
||||
x = x.flip(axis + 1)
|
||||
new_shape = list(shape)
|
||||
return x.reshape(*new_shape)
|
||||
|
||||
w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
|
||||
w13_weight = swap_every_two_rows(w13_weight, -2)
|
||||
w13_bias = swap_every_two_rows(w13_bias, -1)
|
||||
|
||||
# Do not interleave as the checkpoint is already interleaved
|
||||
|
||||
# Shuffle weights and scaling factors for transposed mma output
|
||||
gemm1_weights_mxfp4_shuffled = []
|
||||
gemm1_scales_mxfp4_shuffled = []
|
||||
gemm2_weights_mxfp4_shuffled = []
|
||||
gemm2_scales_mxfp4_shuffled = []
|
||||
gemm1_bias_shuffled = []
|
||||
gemm2_bias_shuffled = []
|
||||
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
|
||||
for i in range(self.num_experts):
|
||||
# w13 weight shuffling
|
||||
permute_indices = _maybe_get_cached_w2_permute_indices(
|
||||
self._cache_permute_indices,
|
||||
w13_weight[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
)
|
||||
gemm1_weights_mxfp4_shuffled.append(w13_weight[i].view(
|
||||
torch.uint8)[permute_indices.to(
|
||||
w13_weight.device)].contiguous())
|
||||
# w13 scale shuffling
|
||||
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
|
||||
self._cache_permute_indices,
|
||||
w13_weight_scale[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
num_elts_per_sf=16,
|
||||
)
|
||||
gemm1_scales_mxfp4_shuffled.append(
|
||||
nvfp4_block_scale_interleave(w13_weight_scale[i].view(
|
||||
torch.uint8)[permute_sf_indices.to(
|
||||
w13_weight_scale.device)].contiguous()))
|
||||
# w13 bias shuffling
|
||||
permute_bias_indices = _maybe_get_cached_w2_permute_indices(
|
||||
self._cache_permute_indices,
|
||||
w13_bias[i].clone().reshape(-1, 1),
|
||||
epilogue_tile_m,
|
||||
)
|
||||
gemm1_bias_shuffled.append(w13_bias[i].clone().reshape(
|
||||
-1,
|
||||
1)[permute_bias_indices.to(w13_bias.device)].contiguous())
|
||||
# w2 weight shuffling
|
||||
permute_indices = _maybe_get_cached_w2_permute_indices(
|
||||
self._cache_permute_indices,
|
||||
w2_weight[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
)
|
||||
gemm2_weights_mxfp4_shuffled.append(w2_weight[i].view(
|
||||
torch.uint8)[permute_indices.to(
|
||||
w2_weight.device)].contiguous())
|
||||
# w2 scale shuffling
|
||||
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
|
||||
self._cache_permute_indices,
|
||||
w2_weight_scale[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
num_elts_per_sf=16,
|
||||
)
|
||||
gemm2_scales_mxfp4_shuffled.append(
|
||||
nvfp4_block_scale_interleave(w2_weight_scale[i].view(
|
||||
torch.uint8)[permute_sf_indices.to(
|
||||
w2_weight_scale.device)].contiguous()))
|
||||
# w2 bias shuffling
|
||||
permute_indices = _maybe_get_cached_w2_permute_indices(
|
||||
self._cache_permute_indices,
|
||||
w2_bias[i].clone().reshape(-1, 1),
|
||||
epilogue_tile_m,
|
||||
)
|
||||
gemm2_bias_shuffled.append(w2_bias[i].clone().reshape(
|
||||
-1, 1)[permute_indices.to(w2_bias.device)].contiguous())
|
||||
|
||||
w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
|
||||
w13_weight_scale = torch.stack(
|
||||
gemm1_scales_mxfp4_shuffled).reshape(
|
||||
self.num_experts, 2 * self.intermediate_size,
|
||||
self.hidden_size // sf_block_size).view(
|
||||
torch.float8_e4m3fn)
|
||||
|
||||
w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
|
||||
w2_weight_scale = torch.stack(gemm2_scales_mxfp4_shuffled).reshape(
|
||||
self.num_experts, self.hidden_size, self.intermediate_size //
|
||||
sf_block_size).view(torch.float8_e4m3fn)
|
||||
|
||||
layer.w13_weight = Parameter(w13_weight, requires_grad=False)
|
||||
layer.w13_weight_scale = Parameter(w13_weight_scale,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
|
||||
layer.w2_weight_scale = Parameter(w2_weight_scale,
|
||||
requires_grad=False)
|
||||
layer.w13_bias = Parameter(
|
||||
torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
|
||||
requires_grad=False)
|
||||
layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape(
|
||||
self.num_experts, -1),
|
||||
requires_grad=False)
|
||||
elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
|
||||
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
|
||||
layer.gemm1_alpha = Parameter(torch.tensor(
|
||||
[1.702] * self.num_experts, dtype=torch.float32).cuda(),
|
||||
requires_grad=False)
|
||||
layer.gemm1_beta = Parameter(torch.tensor(
|
||||
[1.0] * self.num_experts, dtype=torch.float32).cuda(),
|
||||
requires_grad=False)
|
||||
layer.gemm1_clamp_limit = Parameter(torch.tensor(
|
||||
[7.0] * self.num_experts, dtype=torch.float32).cuda(),
|
||||
requires_grad=False)
|
||||
|
||||
sf_block_size = 32 # mxfp4 block size
|
||||
|
||||
# Common shape assertions
|
||||
assert (layer.w13_weight.dim() == 3
|
||||
and layer.w13_weight.shape[0] == self.num_experts
|
||||
and layer.w13_weight.shape[1] == self.intermediate_size * 2
|
||||
and layer.w13_weight.shape[2] == self.hidden_size // 2)
|
||||
assert (layer.w13_weight_scale.dim() == 3
|
||||
and layer.w13_weight_scale.shape[0] == self.num_experts
|
||||
and layer.w13_weight_scale.shape[1]
|
||||
== self.intermediate_size * 2
|
||||
and layer.w13_weight_scale.shape[2]
|
||||
== self.hidden_size // sf_block_size)
|
||||
assert (layer.w2_weight.dim() == 3
|
||||
and layer.w2_weight.shape[0] == self.num_experts
|
||||
and layer.w2_weight.shape[1] == self.hidden_size and
|
||||
layer.w2_weight.shape[2] == self.intermediate_size // 2)
|
||||
assert (layer.w2_weight_scale.dim() == 3
|
||||
and layer.w2_weight_scale.shape[1] == self.hidden_size
|
||||
and layer.w2_weight_scale.shape[2]
|
||||
== self.intermediate_size // sf_block_size)
|
||||
assert (layer.w13_bias.dim() == 2
|
||||
and layer.w13_bias.shape[0] == self.num_experts
|
||||
and layer.w13_bias.shape[1] == self.intermediate_size * 2)
|
||||
assert (layer.w2_bias.dim() == 2
|
||||
and layer.w2_bias.shape[0] == self.num_experts
|
||||
and layer.w2_bias.shape[1] == self.hidden_size)
|
||||
|
||||
# De-interleave and swap for w13 weight, bias, and scales
|
||||
w13_w = layer.w13_weight.data
|
||||
gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :]
|
||||
deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1)
|
||||
w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1)
|
||||
w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
|
||||
|
||||
w13_b = layer.w13_bias.data.to(torch.float32)
|
||||
gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2]
|
||||
deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1)
|
||||
b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1)
|
||||
w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
|
||||
|
||||
w13_s = layer.w13_weight_scale.data
|
||||
gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :]
|
||||
deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1)
|
||||
s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1)
|
||||
w13_scale_swapped = torch.cat([s3, s1], dim=1)
|
||||
|
||||
if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
|
||||
from flashinfer import block_scale_interleave
|
||||
|
||||
orig_shape = w13_scale_swapped.shape
|
||||
w13_scale_interleaved = block_scale_interleave(
|
||||
w13_scale_swapped.view(torch.uint8)).reshape(orig_shape)
|
||||
|
||||
w2_s = layer.w2_weight_scale.data
|
||||
orig_shape = w2_s.shape
|
||||
w2_scale_interleaved = block_scale_interleave(
|
||||
w2_s.view(torch.uint8)).reshape(orig_shape)
|
||||
|
||||
layer.w13_weight = Parameter(w13_weight_swapped,
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale = Parameter(w13_scale_interleaved,
|
||||
requires_grad=False)
|
||||
layer.w13_bias = Parameter(w13_bias_swapped,
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale = Parameter(w2_scale_interleaved,
|
||||
requires_grad=False)
|
||||
elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
|
||||
|
||||
def _interleave_mxfp4_cutlass_sm90(w):
|
||||
w_shape = w.shape
|
||||
w_interleaved = w.reshape(w_shape[0], w_shape[1],
|
||||
(w_shape[2] // 4), 4)
|
||||
w_interleaved = w_interleaved.permute(0, 2, 1, 3)
|
||||
w_interleaved = w_interleaved.reshape(
|
||||
w_shape[0], w_shape[2] // 4, w_shape[1] * 4)
|
||||
return w_interleaved
|
||||
|
||||
w31_scales = w13_scale_swapped.to(torch.uint8).view(
|
||||
torch.uint8)
|
||||
w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(
|
||||
w31_scales)
|
||||
|
||||
w2_weight_scale = layer.w2_weight_scale.data
|
||||
w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8)
|
||||
w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90(
|
||||
w2_scales)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(torch.cat([w3_w, w1_w],
|
||||
dim=1),
|
||||
requires_grad=False)
|
||||
layer.w13_bias = torch.nn.Parameter(w13_bias_swapped,
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
w31_scales_interleaved, requires_grad=False)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(
|
||||
w2_scales_interleaved, requires_grad=False)
|
||||
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
|
||||
w13_bias = layer.w13_bias.to(torch.float32)
|
||||
w2_bias = layer.w2_bias.to(torch.float32)
|
||||
|
||||
layer.w13_bias = Parameter(w13_bias, requires_grad=False)
|
||||
layer.w2_bias = Parameter(w2_bias, requires_grad=False)
|
||||
|
||||
# Ideally we'd use FusedMoEModularKernel.prepare_finalize object
|
||||
# (stored in self.fused_experts) to determine if the MoE has a
|
||||
# batched activation format. As self.fused_experts is not
|
||||
# initialized at this point, we resort to checking the MoE config
|
||||
# directly.
|
||||
is_batched_moe = (self.moe.use_pplx_kernels
|
||||
or self.moe.use_deepep_ll_kernels)
|
||||
if is_batched_moe:
|
||||
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
|
||||
else:
|
||||
num_warps = 8
|
||||
|
||||
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
|
||||
layer.w13_weight, layer.w13_weight_scale, num_warps)
|
||||
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
|
||||
layer.w2_weight, layer.w2_weight_scale, num_warps)
|
||||
|
||||
self.w13_precision_config = PrecisionConfig(
|
||||
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex))
|
||||
self.w2_precision_config = PrecisionConfig(
|
||||
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex))
|
||||
|
||||
self.w13_weight_triton_tensor = w13_weight
|
||||
self.w2_weight_triton_tensor = w2_weight
|
||||
|
||||
# need to delete the original weights to save memory on single GPU
|
||||
del layer.w13_weight
|
||||
del layer.w2_weight
|
||||
layer.w13_weight = None
|
||||
layer.w2_weight = None
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
|
||||
|
||||
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
|
||||
# Number of tokens in the input tensor.
|
||||
num_tokens = x.shape[0]
|
||||
# Factor to account for the imbalance of the experts.
|
||||
# factor equals to the
|
||||
# max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
|
||||
# - 1.0 means perfect expert distribution.
|
||||
# - > 1.0 means some experts have more
|
||||
# tokens than the perfect distribution.
|
||||
# - < 1.0 does not make sense.
|
||||
imbalance_factor = 1.3
|
||||
# Calculate the number of tokens per expert
|
||||
# assuming perfect distribution.
|
||||
num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
|
||||
# Apply the imbalance factor.
|
||||
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
|
||||
# And pad the number to the next power of 2.
|
||||
tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
|
||||
# Cap to 8-64 tokens per CTA tile
|
||||
# as it's the range supported by the kernel.
|
||||
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
|
||||
|
||||
return tile_tokens_dim
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
|
||||
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
|
||||
return None
|
||||
|
||||
if self.mxfp4_backend == Mxfp4Backend.TRITON:
|
||||
w1_scale = self.w13_precision_config
|
||||
w2_scale = self.w2_precision_config
|
||||
return mxfp4_w4a16_moe_quant_config(
|
||||
w1_bias=layer.w13_bias,
|
||||
w2_bias=layer.w2_bias,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
)
|
||||
else:
|
||||
w1_scale = layer.w13_weight_scale
|
||||
w2_scale = layer.w2_weight_scale
|
||||
return mxfp4_w4a4_moe_quant_config(
|
||||
w1_bias=layer.w13_bias,
|
||||
w2_bias=layer.w2_bias,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
if (prepare_finalize.activation_format ==
|
||||
mk.FusedMoEActivationFormat.BatchedExperts):
|
||||
raise NotImplementedError(
|
||||
"Mxfp4 does not support batched experts format for EP")
|
||||
else:
|
||||
assert self.moe_quant_config is not None
|
||||
if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
|
||||
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
|
||||
# B200 code-path
|
||||
kwargs = {
|
||||
"gemm1_alpha": layer.gemm1_alpha,
|
||||
"gemm1_beta": layer.gemm1_beta,
|
||||
"gemm1_clamp_limit": layer.gemm1_clamp_limit,
|
||||
# TODO(bnell): part of quant_config
|
||||
"max_capture_size": self.max_capture_size,
|
||||
}
|
||||
return TrtLlmGenExperts(self.moe, self.moe_quant_config,
|
||||
**kwargs)
|
||||
else:
|
||||
return OAITritonExperts(self.moe_quant_config)
|
||||
|
||||
def _route_and_experts(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert isinstance(self.fused_experts, mk.FusedMoEModularKernel)
|
||||
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype,
|
||||
enable_eplb=enable_eplb,
|
||||
expert_map=expert_map,
|
||||
expert_load_view=expert_load_view,
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
logical_replica_count=logical_replica_count)
|
||||
|
||||
w13_weight = (self.w13_weight_triton_tensor
|
||||
if layer.w13_weight is None else layer.w13_weight)
|
||||
w2_weight = (self.w2_weight_triton_tensor
|
||||
if layer.w2_weight is None else layer.w2_weight)
|
||||
assert all([w is not None for w in [w13_weight, w2_weight]])
|
||||
|
||||
return self.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=w13_weight,
|
||||
w2=w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError("EPLB is not supported for mxfp4")
|
||||
|
||||
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
layer.w13_bias,
|
||||
layer.w2_bias,
|
||||
layer.w13_weight_scale,
|
||||
layer.w2_weight_scale,
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_scale1=None,
|
||||
global_scale2=None,
|
||||
quant_type_id=scalar_types.float4_e2m1f.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
activation=activation,
|
||||
expert_map=expert_map)
|
||||
|
||||
if self.fused_experts is not None:
|
||||
return self._route_and_experts(
|
||||
layer,
|
||||
x,
|
||||
router_logits,
|
||||
top_k,
|
||||
renormalize,
|
||||
use_grouped_topk,
|
||||
topk_group,
|
||||
num_expert_group,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
custom_routing_function,
|
||||
scoring_func,
|
||||
e_score_correction_bias,
|
||||
apply_router_weight_on_input,
|
||||
activation,
|
||||
enable_eplb,
|
||||
expert_load_view,
|
||||
logical_to_physical_map,
|
||||
logical_replica_count,
|
||||
)
|
||||
|
||||
assert _can_support_mxfp4(
|
||||
use_grouped_topk, topk_group, num_expert_group, expert_map,
|
||||
custom_routing_function, e_score_correction_bias,
|
||||
apply_router_weight_on_input, scoring_func, activation,
|
||||
expert_load_view, logical_to_physical_map,
|
||||
logical_replica_count), (
|
||||
"MXFP4 are not supported with this configuration.")
|
||||
|
||||
if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
|
||||
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
|
||||
from flashinfer import trtllm_fp4_block_scale_moe
|
||||
if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16:
|
||||
assert x.dtype == torch.bfloat16
|
||||
x_quant = x
|
||||
x_scale = None
|
||||
elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM:
|
||||
from flashinfer import mxfp8_quantize
|
||||
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
|
||||
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
|
||||
*x.shape[:-1], -1)
|
||||
|
||||
trtllm_gen_output = trtllm_fp4_block_scale_moe(
|
||||
router_logits.to(torch.bfloat16),
|
||||
None, # routing_bias
|
||||
x_quant,
|
||||
x_scale,
|
||||
layer.w13_weight, # uint8 (e2m1 x 2)
|
||||
layer.w13_weight_scale, # uint8 (e4m3 x 2)
|
||||
layer.w13_bias, # fp32 per expert per channel
|
||||
layer.gemm1_alpha, # fp32 per expert
|
||||
layer.gemm1_beta, # fp32 per expert
|
||||
layer.gemm1_clamp_limit, # fp32 per expert
|
||||
layer.w2_weight, # uint8 (e2m1 x 2)
|
||||
layer.w2_weight_scale, # ue8m0
|
||||
layer.w2_bias, # fp32 per expert per channel
|
||||
None, # output1_scale_scalar
|
||||
None, # output1_scale_gate_scalar
|
||||
None, # output2_scale_scalar
|
||||
global_num_experts,
|
||||
top_k,
|
||||
None, # n_group
|
||||
None, # topk_group
|
||||
self.intermediate_size, # padded to multiple of 256
|
||||
layer.ep_rank * layer.local_num_experts, # local_expert_offset
|
||||
self.num_experts, # local num experts
|
||||
None,
|
||||
self._get_tile_tokens_dim(x, top_k),
|
||||
1 if renormalize else 0, # routing_method_type, renormalize
|
||||
True, # do finalize
|
||||
tune_max_num_tokens=self.max_capture_size,
|
||||
)[0]
|
||||
return trtllm_gen_output
|
||||
elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
|
||||
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
|
||||
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
|
||||
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
# Backend-specific preparation
|
||||
if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
|
||||
|
||||
from flashinfer import mxfp8_quantize
|
||||
|
||||
x_quant, x_scale = mxfp8_quantize(x, True, 32)
|
||||
|
||||
fake_input_scale = torch.ones(self.num_experts,
|
||||
device=x.device)
|
||||
quant_scales = [
|
||||
layer.w13_weight_scale.contiguous().view(torch.int32),
|
||||
fake_input_scale,
|
||||
layer.w2_weight_scale.contiguous().view(torch.int32),
|
||||
fake_input_scale,
|
||||
]
|
||||
|
||||
fi_input = x_quant
|
||||
extra_kwargs = dict(
|
||||
use_mxfp8_act_scaling=True,
|
||||
input_sf=x_scale,
|
||||
fc1_expert_weights=layer.w13_weight.contiguous().view(
|
||||
torch.long),
|
||||
fc2_expert_weights=layer.w2_weight.contiguous().view(
|
||||
torch.long),
|
||||
)
|
||||
elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
|
||||
assert x.dtype == torch.bfloat16
|
||||
|
||||
quant_scales = [
|
||||
layer.w13_weight_scale,
|
||||
layer.w2_weight_scale,
|
||||
]
|
||||
|
||||
fi_input = x
|
||||
extra_kwargs = dict(
|
||||
use_w4_group_scaling=True,
|
||||
fc1_expert_weights=layer.w13_weight,
|
||||
fc2_expert_weights=layer.w2_weight,
|
||||
)
|
||||
|
||||
output = torch.empty_like(x, dtype=torch.bfloat16)
|
||||
_ = flashinfer_cutlass_fused_moe(
|
||||
input=fi_input,
|
||||
token_selected_experts=topk_ids.to(torch.int).contiguous(),
|
||||
token_final_scales=topk_weights,
|
||||
output_dtype=torch.bfloat16,
|
||||
output=output,
|
||||
quant_scales=quant_scales,
|
||||
fc1_expert_biases=layer.w13_bias,
|
||||
fc2_expert_biases=layer.w2_bias,
|
||||
swiglu_alpha=layer.gemm1_alpha,
|
||||
swiglu_beta=layer.gemm1_beta,
|
||||
swiglu_limit=layer.gemm1_clamp_limit,
|
||||
tp_size=self.moe.tp_size,
|
||||
tp_rank=self.moe.tp_rank,
|
||||
ep_size=self.moe.ep_size,
|
||||
ep_rank=self.moe.ep_rank,
|
||||
tune_max_num_tokens=self.max_capture_size,
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
return output
|
||||
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
|
||||
triton_kernel_moe_forward)
|
||||
return triton_kernel_moe_forward(
|
||||
hidden_states=x,
|
||||
w1=self.w13_weight_triton_tensor,
|
||||
w2=self.w2_weight_triton_tensor,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
|
||||
306
vllm/model_executor/layers/quantization/petit.py
Normal file
306
vllm/model_executor/layers/quantization/petit.py
Normal file
@@ -0,0 +1,306 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.petit_utils import (
|
||||
apply_petit_nvfp4_linear, prepare_nvfp4_layer_for_petit,
|
||||
verify_petit_nvfp4_supported)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
from vllm.model_executor.parameter import (ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Initialize logger for the module
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# Configuration class to support the NVFP4 quantized model
|
||||
# generated by the ModelOpt quantization tool
|
||||
class PetitNvFp4Config(QuantizationConfig):
|
||||
"""Config class for Petit FP4."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
is_checkpoint_nvfp4_serialized: bool = False,
|
||||
kv_cache_quant_algo: Optional[str] = None,
|
||||
group_size: Optional[int] = None,
|
||||
exclude_modules: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
self._check_hardware_support()
|
||||
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
|
||||
if is_checkpoint_nvfp4_serialized:
|
||||
logger.warning("Detected nvfp4 checkpoint. Please note that the "
|
||||
"format is experimental and subject to change.")
|
||||
self.group_size = group_size
|
||||
self.kv_cache_quant_algo = kv_cache_quant_algo
|
||||
self.exclude_modules = exclude_modules
|
||||
|
||||
def _check_hardware_support(self) -> None:
|
||||
"""
|
||||
Verifies that the current hardware is supported by the Petit backend.
|
||||
This backend is specifically designed for AMD GPUs and is not
|
||||
supported on the CUDA platform.
|
||||
"""
|
||||
# This check ensures the code is NOT running on an NVIDIA GPU.
|
||||
if current_platform.is_cuda():
|
||||
raise ValueError(
|
||||
"The 'petit' quantization backend is designed for AMD GPUs "
|
||||
"and is not supported on the CUDA platform. For NVIDIA GPUs, "
|
||||
"please use a different quantization method such as FP8, AWQ, "
|
||||
"or GPTQ.")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "petit_nvfp4"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# Petit supports the gfx90a and gfx942 GPUs
|
||||
return 90
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["hf_quant_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "PetitNvFp4Config":
|
||||
qc = cls.get_from_keys(config, ["quantization"])
|
||||
|
||||
quant_method_raw = qc.get("quant_algo")
|
||||
if not isinstance(quant_method_raw, str) or not quant_method_raw:
|
||||
raise ValueError(
|
||||
"Missing or invalid 'quant_algo' in quantization config.")
|
||||
quant_method = quant_method_raw.upper()
|
||||
|
||||
group_size_raw = qc.get("group_size")
|
||||
if not isinstance(group_size_raw, int):
|
||||
raise ValueError(
|
||||
"Missing or invalid 'group_size' (int) in hf_quant_config.json."
|
||||
)
|
||||
group_size = group_size_raw
|
||||
|
||||
verify_petit_nvfp4_supported(quant_method, group_size)
|
||||
|
||||
kv_cache_quant_algo_raw = qc.get("kv_cache_quant_algo") or "auto"
|
||||
if not isinstance(kv_cache_quant_algo_raw, str):
|
||||
raise ValueError(
|
||||
"'kv_cache_quant_algo' must be a string if provided.")
|
||||
kv_cache_quant_algo = kv_cache_quant_algo_raw
|
||||
|
||||
exclude_raw = qc.get("exclude_modules", [])
|
||||
if exclude_raw is None:
|
||||
exclude_modules: list[str] = []
|
||||
elif isinstance(exclude_raw, list) and all(
|
||||
isinstance(x, str) for x in exclude_raw):
|
||||
exclude_modules = exclude_raw
|
||||
else:
|
||||
raise ValueError(
|
||||
"'exclude_modules' must be a list[str] (or omitted).")
|
||||
|
||||
is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
|
||||
|
||||
return cls(
|
||||
is_checkpoint_nvfp4_serialized=is_checkpoint_nvfp4_serialized,
|
||||
kv_cache_quant_algo=kv_cache_quant_algo,
|
||||
group_size=group_size,
|
||||
exclude_modules=exclude_modules,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
if not current_platform.is_rocm():
|
||||
return None
|
||||
|
||||
qc = hf_quant_cfg.get("quantization", hf_quant_cfg)
|
||||
algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper()
|
||||
if algo in ("NVFP4", "MODELOPT_FP4", "MODELOPT"):
|
||||
return cls.get_name() # "petit_nvfp4"
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_petit_nvfp4_compatible(cls, quant_config: dict[str, Any]) -> bool:
|
||||
qc = quant_config.get("quantization", quant_config)
|
||||
algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper()
|
||||
return algo == "NVFP4"
|
||||
|
||||
def is_layer_excluded(self, prefix: str,
|
||||
exclude_modules: list[str]) -> bool:
|
||||
for pattern in exclude_modules:
|
||||
regex_str = pattern.replace(".", r"\.").replace("*", r".*")
|
||||
if re.fullmatch(regex_str, prefix):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
|
||||
exclude = self.require_exclude_modules()
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(prefix, exclude) or self.is_layer_excluded(
|
||||
prefix, exclude):
|
||||
return UnquantizedLinearMethod()
|
||||
return PetitNvFp4LinearMethod(self)
|
||||
elif isinstance(layer, Attention):
|
||||
return PetitFp8KVCacheMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> list[str]:
|
||||
return []
|
||||
|
||||
def require_group_size(self) -> int:
|
||||
if self.group_size is None:
|
||||
logger.warning("group_size not set; defaulting to 16 for NVFP4.")
|
||||
return 16
|
||||
return self.group_size
|
||||
|
||||
def require_kv_cache_quant_algo(self) -> str:
|
||||
return self.kv_cache_quant_algo or "auto"
|
||||
|
||||
def require_exclude_modules(self) -> list[str]:
|
||||
return list(self.exclude_modules or [])
|
||||
|
||||
|
||||
class PetitFp8KVCacheMethod(BaseKVCacheMethod):
|
||||
"""
|
||||
Supports loading kv-cache scaling factors from FP8 checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: PetitNvFp4Config):
|
||||
super().__init__(quant_config)
|
||||
|
||||
|
||||
class PetitNvFp4LinearMethod(LinearMethodBase):
|
||||
"""Linear method for NVFP4.
|
||||
Supports loading NVFP4 checkpoints with the following structure:
|
||||
|
||||
|Tensor Name | datatype | shape |
|
||||
|----------------------------------------------------|
|
||||
|input_scale | torch.float32 | scalar |
|
||||
|weight | NVFP4(SE2M1) | [1, X, y/2] |
|
||||
|weight_scale | FP8-E4M3 | [X, Y] |
|
||||
|weight_scale_2 | torch.float32 | scalar |
|
||||
|
||||
The weights are quantized per block of 16 elements.
|
||||
Args: quant_config: The ModelOpt quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: PetitNvFp4Config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del input_size, output_size
|
||||
if not self.quant_config.is_checkpoint_nvfp4_serialized:
|
||||
raise ValueError("NVFP4 quantization was selected, "
|
||||
" dynamic quantization is not supported.")
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
if input_size_per_partition % 16 != 0:
|
||||
raise ValueError("Unsupported model when in features size is "
|
||||
"not multiple of 16")
|
||||
|
||||
weight_dtype = (torch.float8_e4m3fn
|
||||
if self.quant_config.is_checkpoint_nvfp4_serialized
|
||||
else params_dtype)
|
||||
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
# 2 fp4 data is packed in one uint8 in the input dimension
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // 2,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
input_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
weight_scale_2 = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale_2", weight_scale_2)
|
||||
|
||||
group_size = self.quant_config.require_group_size()
|
||||
weight_scale = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // group_size,
|
||||
dtype=weight_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
input_scale_2 = layer.input_scale.max().to(torch.float32)
|
||||
weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
|
||||
layer.input_scale = Parameter(input_scale_2, requires_grad=False)
|
||||
layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
|
||||
layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
|
||||
requires_grad=False)
|
||||
|
||||
prepare_nvfp4_layer_for_petit(layer)
|
||||
del layer.input_scale
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return apply_petit_nvfp4_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_scale_2=layer.weight_scale_2,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias,
|
||||
)
|
||||
129
vllm/model_executor/layers/quantization/ptpc_fp8.py
Normal file
129
vllm/model_executor/layers/quantization/ptpc_fp8.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
|
||||
Fp8KVCacheMethod,
|
||||
Fp8LinearMethod)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape, is_layer_skipped)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class PTPCFp8Config(Fp8Config):
|
||||
"""Config class for Per-Token-Per-Channel Dynamic Quantization Fp8."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
activation_scheme: str = "dynamic",
|
||||
ignored_layers: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
if not current_platform.is_rocm():
|
||||
raise ValueError(
|
||||
"ptpc_fp8 quantization is supported only on ROCm.")
|
||||
|
||||
if not current_platform.has_device_capability(94):
|
||||
raise ValueError(
|
||||
"ptpc_fp8 quantization is supported only on AMD Instinct MI300 GPUs and newer." # noqa: E501
|
||||
)
|
||||
if activation_scheme == "static":
|
||||
raise ValueError(
|
||||
"ptpc_fp8 as of now only support dynamic quantization.")
|
||||
|
||||
super().__init__(is_checkpoint_fp8_serialized=False,
|
||||
activation_scheme=activation_scheme,
|
||||
ignored_layers=ignored_layers)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "ptpc_fp8"
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "PTPCFp8Config":
|
||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
||||
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
|
||||
return cls(activation_scheme=activation_scheme,
|
||||
ignored_layers=ignored_layers)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(prefix, self.ignored_layers):
|
||||
return UnquantizedLinearMethod()
|
||||
return PTPCFp8LinearMethod(self)
|
||||
elif isinstance(layer, Attention):
|
||||
return Fp8KVCacheMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class PTPCFp8LinearMethod(Fp8LinearMethod):
|
||||
"""Linear method for Per-Token and Per-Channel FP8 Quantization.
|
||||
Only supports loading quantized BF16 model checkpoints with dynamic
|
||||
activation scaling. To load FP16 model checkpoints, user must specify
|
||||
to convert the FP16 model weight loading into BF16.
|
||||
The weight scaling factor will be initialized after
|
||||
the model weights are loaded.
|
||||
|
||||
Limitations:
|
||||
1. Only support float8_e4m3fnuz data type due to the limitation of
|
||||
torch._scaled_mm (https://github.com/ROCm/pytorch/blob/8c0504d7f3fb0ee4c278c096a5c3caedb01129fa/aten/src/ATen/native/cuda/Blas.cpp#L1041)
|
||||
|
||||
Args:
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: PTPCFp8Config):
|
||||
assert current_platform.is_rocm(), \
|
||||
"PTPCFp8LinearMethod is only supported on ROCm."
|
||||
super().__init__(quant_config=quant_config)
|
||||
# Force weight quantization
|
||||
self.quant_config.is_checkpoint_fp8_serialized = False
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.weight = torch.nn.Parameter(layer.weight.data,
|
||||
requires_grad=False)
|
||||
|
||||
assert layer.weight.data.dtype == torch.bfloat16, \
|
||||
f"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. {str(layer.weight.data.dtype)} is specified." # noqa: E501
|
||||
# Quantize the weights.
|
||||
qweight, weight_scale = ops.scaled_fp8_quant(
|
||||
layer.weight, scale=None, use_per_token_if_dynamic=True)
|
||||
|
||||
# Update the layer with the new values.
|
||||
layer.weight = Parameter(
|
||||
qweight.t(), requires_grad=False) # Pretranspose the weight
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
layer.input_scale = None
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
return self.fp8_linear.apply(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=None,
|
||||
input_scale_ub=None,
|
||||
bias=bias)
|
||||
432
vllm/model_executor/layers/quantization/quark/quark.py
Normal file
432
vllm/model_executor/layers/quantization/quark/quark.py
Normal file
@@ -0,0 +1,432 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import fnmatch
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501
|
||||
QuarkMoEMethod)
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import (
|
||||
QuarkScheme, QuarkW4A4MXFP4, QuarkW8A8Fp8, QuarkW8A8Int8)
|
||||
from vllm.model_executor.layers.quantization.quark.utils import (
|
||||
deep_compare, should_ignore_layer)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
__all__ = ["QuarkLinearMethod"]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class QuarkConfig(QuantizationConfig):
|
||||
|
||||
def __init__(self,
|
||||
quant_config: dict[str, Any],
|
||||
kv_cache_group: Optional[list[str]] = None,
|
||||
kv_cache_config: Optional[dict[str, Any]] = None,
|
||||
pack_method: str = "reorder"):
|
||||
super().__init__()
|
||||
if kv_cache_group is None:
|
||||
kv_cache_group = []
|
||||
self.quant_config = quant_config
|
||||
self.kv_cache_group = kv_cache_group
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.pack_method = pack_method
|
||||
|
||||
def get_linear_method(self) -> "QuarkLinearMethod":
|
||||
return QuarkLinearMethod(self)
|
||||
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "quark"
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
|
||||
# Check if the layer is skipped for quantization.
|
||||
exclude_layers = cast(list[str], self.quant_config.get("exclude"))
|
||||
if should_ignore_layer(prefix,
|
||||
ignore=exclude_layers,
|
||||
fused_mapping=self.packed_modules_mapping):
|
||||
return UnquantizedLinearMethod()
|
||||
if isinstance(layer, LinearBase):
|
||||
scheme = self.get_scheme(layer=layer, layer_name=prefix)
|
||||
layer.scheme = scheme
|
||||
return QuarkLinearMethod(self)
|
||||
if isinstance(layer, Attention):
|
||||
return QuarkKVCacheMethod(self)
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
return QuarkMoEMethod.get_moe_method(self,
|
||||
module=layer,
|
||||
layer_name=prefix)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
|
||||
export_config = config.get("export")
|
||||
if export_config is None:
|
||||
raise ValueError("The export key should be included in "
|
||||
"the configurations of Quark quantized model")
|
||||
kv_cache_group = cast(list[str], export_config.get("kv_cache_group"))
|
||||
pack_method = cast(str, export_config.get("pack_method"))
|
||||
|
||||
# In the export model of quark, the quantization configuration
|
||||
# of kv_cache is stored in layer_quant_config. First, it is
|
||||
# judged whether kv_cache_group exists, and then it is judged
|
||||
# whether layer_quant_config has a quantization configuration
|
||||
# that matches kv_cache.
|
||||
if len(kv_cache_group) == 0:
|
||||
kv_cache_config = None
|
||||
else:
|
||||
kv_cache_set = set(kv_cache_group)
|
||||
layer_quant_config = cast(dict[str, Any],
|
||||
config.get("layer_quant_config"))
|
||||
layer_quant_names = list(layer_quant_config.keys())
|
||||
layer_quant_set = set(layer_quant_names)
|
||||
|
||||
if not kv_cache_set.issubset(layer_quant_set):
|
||||
raise ValueError("The Quark quantized model has the "
|
||||
"kv_cache_group parameter setting, "
|
||||
"but no kv_cache quantization settings "
|
||||
"were found in the quantization "
|
||||
"configuration.")
|
||||
|
||||
q_configs = [
|
||||
cast(dict[str, Any], layer_quant_config.get(name))
|
||||
for name in kv_cache_group
|
||||
]
|
||||
if not all(
|
||||
deep_compare(q_config, q_configs[0])
|
||||
for q_config in q_configs):
|
||||
raise ValueError(
|
||||
"The quantization method used for kv_cache should "
|
||||
"be the same, but the quantization method for the "
|
||||
"kv_cache layer in the config is different.")
|
||||
kv_cache_config = q_configs[0].get("output_tensors")
|
||||
if kv_cache_config is None:
|
||||
raise ValueError(
|
||||
"The kv_cache quantization configuration is empty.")
|
||||
|
||||
# Since we have already set kv_cache quantization configurations,
|
||||
# we will remove the quantization configuration for the
|
||||
# output_tensors corresponding to the kv_cache layer.
|
||||
for q_config in q_configs:
|
||||
q_config["output_tensors"] = None
|
||||
|
||||
# In case q_proj output is also quantized, remove the configuration
|
||||
# to keep qkv consistency.
|
||||
q_proj_q_config = cast(dict[str, Any],
|
||||
layer_quant_config.get("*q_proj"))
|
||||
if q_proj_q_config is not None:
|
||||
q_proj_q_config["output_tensors"] = None
|
||||
|
||||
return cls(quant_config=config,
|
||||
kv_cache_group=kv_cache_group,
|
||||
kv_cache_config=kv_cache_config,
|
||||
pack_method=pack_method)
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
def _check_scheme_supported(self,
|
||||
min_capability: int,
|
||||
error: bool = True) -> bool:
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
|
||||
if capability_tuple is not None:
|
||||
capability = capability_tuple.to_int()
|
||||
supported = capability >= min_capability
|
||||
if error and not supported:
|
||||
raise RuntimeError(
|
||||
"Quantization scheme is not supported for ",
|
||||
f"the current GPU. Min capability: {min_capability}. ",
|
||||
f"Current capability: {capability}.")
|
||||
return supported
|
||||
else:
|
||||
return False
|
||||
|
||||
def _is_fp8_w8a8(self, weight_quant: Optional[dict[str, Any]],
|
||||
input_quant: Optional[dict[str, Any]]) -> bool:
|
||||
# Confirm weights and input quantized.
|
||||
if weight_quant is None or input_quant is None:
|
||||
return False
|
||||
|
||||
# Confirm weight scheme is supported
|
||||
is_fp8_dtype = (weight_quant.get("dtype") == "fp8_e4m3"
|
||||
and input_quant.get("dtype") == "fp8_e4m3")
|
||||
is_static_weight = not weight_quant.get("is_dynamic")
|
||||
is_per_tensor_or_channel_weight = (weight_quant.get("qscheme")
|
||||
in ["per_tensor", "per_channel"])
|
||||
|
||||
if not (is_fp8_dtype and is_static_weight
|
||||
and is_per_tensor_or_channel_weight):
|
||||
return False
|
||||
|
||||
# Dynamic quantization is always supported if weights supported.
|
||||
if input_quant.get("is_dynamic"):
|
||||
return True
|
||||
|
||||
# Confirm activation scheme is supported.
|
||||
is_per_tensor_activation = (input_quant.get("qscheme") == "per_tensor")
|
||||
return is_per_tensor_activation
|
||||
|
||||
def _is_static_tensor_w8a8(self, weight_quant: Optional[dict[str, Any]],
|
||||
input_quant: Optional[dict[str, Any]]) -> bool:
|
||||
# Confirm weights and input quantized.
|
||||
if weight_quant is None or input_quant is None:
|
||||
return False
|
||||
|
||||
is_int8_dtype = (weight_quant.get("dtype") == "int8"
|
||||
and input_quant.get("dtype") == "int8")
|
||||
|
||||
is_tensor = (weight_quant.get("qscheme")
|
||||
in ["per_tensor", "per_channel"]
|
||||
and input_quant.get("qscheme") == "per_tensor")
|
||||
|
||||
is_static = (not weight_quant.get("is_dynamic")
|
||||
and not input_quant.get("is_dynamic"))
|
||||
|
||||
is_weight_symmetric = (weight_quant.get("symmetric") is True)
|
||||
|
||||
# Both symmetric and asymmetric input quantization supported.
|
||||
# Only symmetric weight quantization supported.
|
||||
return is_int8_dtype and is_tensor and is_weight_symmetric and is_static
|
||||
|
||||
def _is_mx_fp4(self, weight_quant: Optional[dict[str, Any]],
|
||||
input_quant: Optional[dict[str, Any]]) -> bool:
|
||||
# Confirm weights and input quantized.
|
||||
if weight_quant is None or input_quant is None:
|
||||
logger.debug("Quark model is not in MX-FP4 format: "
|
||||
"weight_quant or input_quant not set")
|
||||
return False
|
||||
|
||||
# Input and weight dtype needs to be fp4.
|
||||
if weight_quant.get("dtype") != "fp4" or input_quant.get(
|
||||
"dtype") != "fp4":
|
||||
logger.debug("Quark model is not in MX-FP4 format: dtype not fp4")
|
||||
return False
|
||||
|
||||
# Input and weight qscheme needs to be per group.
|
||||
if weight_quant.get("qscheme") != "per_group" or input_quant.get(
|
||||
"qscheme") != "per_group":
|
||||
logger.debug("Quark model is not in MX-FP4 format: not per_group")
|
||||
return False
|
||||
|
||||
# Input and weight group size needs to be 32.
|
||||
if weight_quant.get("group_size") != 32 or input_quant.get(
|
||||
"group_size") != 32:
|
||||
logger.debug(
|
||||
"Quark model is not in MX-FP4 format: not group_size=32")
|
||||
return False
|
||||
|
||||
# Activations need to use dynamic quantization.
|
||||
if input_quant.get("is_dynamic") is False:
|
||||
logger.debug(
|
||||
"Quark model is not in MX-FP4 format: not activation dynamic")
|
||||
return False
|
||||
|
||||
# Activations and weight scales need to be in e8m0 format.
|
||||
if weight_quant.get("scale_format") != "e8m0" or input_quant.get(
|
||||
"scale_format") != "e8m0":
|
||||
logger.debug(
|
||||
"Quark model is not in MX-FP4 format: not scale_format e8m0")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _find_matched_config(self, layer_name: str,
|
||||
module: torch.nn.Module) -> dict[str, Any]:
|
||||
|
||||
proj_name = layer_name.split(".")[-1]
|
||||
if proj_name in self.packed_modules_mapping:
|
||||
shard_proj_names = self.packed_modules_mapping[proj_name]
|
||||
|
||||
# Convert fused_name --> [shard_names]
|
||||
shard_names = [
|
||||
layer_name.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in shard_proj_names
|
||||
]
|
||||
shard_configs = [
|
||||
self._find_matched_config(shard_name, module)
|
||||
for shard_name in shard_names
|
||||
]
|
||||
if not all(
|
||||
deep_compare(q_config, shard_configs[0])
|
||||
for q_config in shard_configs):
|
||||
raise ValueError(
|
||||
f"Found a different quantization configuration for "
|
||||
f"{shard_proj_names} in {layer_name}. vLLM "
|
||||
"requires all to use the same scheme.")
|
||||
return shard_configs[0]
|
||||
else:
|
||||
layer_quant_config = cast(
|
||||
dict[str, Any], self.quant_config.get("layer_quant_config"))
|
||||
for name_pattern in layer_quant_config:
|
||||
if fnmatch.fnmatch(layer_name, name_pattern):
|
||||
return layer_quant_config[name_pattern]
|
||||
|
||||
layer_type = cast(str, type(module))
|
||||
layer_type_quant_config = cast(
|
||||
dict[str, Any],
|
||||
self.quant_config.get("layer_type_quant_config"))
|
||||
if layer_type in layer_type_quant_config:
|
||||
return layer_type_quant_config[layer_type]
|
||||
|
||||
global_quant_config = cast(
|
||||
dict[str, Any], self.quant_config.get("global_quant_config"))
|
||||
return global_quant_config
|
||||
|
||||
def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
|
||||
if config.get("output_tensors") or config.get("bias"):
|
||||
raise NotImplementedError(
|
||||
"Currently, Quark models with output_tensors "
|
||||
"and bias quantized are not supported")
|
||||
weight_config = cast(dict[str, Any], config.get("weight"))
|
||||
input_config = cast(dict[str, Any], config.get("input_tensors"))
|
||||
|
||||
if self._is_fp8_w8a8(weight_config, input_config):
|
||||
is_fp8_w8a8_supported = self._check_scheme_supported(
|
||||
QuarkW8A8Fp8.get_min_capability(), error=False)
|
||||
if is_fp8_w8a8_supported:
|
||||
return QuarkW8A8Fp8(weight_config, input_config)
|
||||
elif self._is_static_tensor_w8a8(weight_config, input_config):
|
||||
weight_qscheme = cast(str, weight_config.get("qscheme"))
|
||||
return QuarkW8A8Int8(qscheme=weight_qscheme,
|
||||
is_static_input_scheme=True,
|
||||
input_symmetric=input_config.get("symmetric"))
|
||||
elif self._is_mx_fp4(weight_config, input_config):
|
||||
return QuarkW4A4MXFP4(weight_config, input_config)
|
||||
|
||||
raise NotImplementedError("No quark compatible scheme was found. "
|
||||
f"Weight config: {weight_config}, "
|
||||
f"Input config: {input_config}")
|
||||
|
||||
def get_scheme(self, layer: torch.nn.Module,
|
||||
layer_name: str) -> "QuarkScheme":
|
||||
|
||||
layer_quant_config = self._find_matched_config(layer_name, layer)
|
||||
|
||||
# Find the quant_scheme
|
||||
scheme = self._get_scheme_from_config(layer_quant_config)
|
||||
# Raise error if device does not support the scheme
|
||||
# (e.g. fp8 needs ada lovelace)
|
||||
self._check_scheme_supported(scheme.get_min_capability())
|
||||
|
||||
return scheme
|
||||
|
||||
def get_cache_scale(self, name: str) -> Optional[str]:
|
||||
"""
|
||||
Check whether the param name matches the format for k/v cache scales
|
||||
in quark. If this is the case, return its equivalent param name
|
||||
expected by vLLM
|
||||
|
||||
:param name: param name
|
||||
:return: matching param name for KV cache scale in vLLM
|
||||
"""
|
||||
if name.endswith(".output_scale") and ".k_proj" in name:
|
||||
return name.replace(".k_proj.output_scale", ".attn.k_scale")
|
||||
if name.endswith(".output_scale") and ".v_proj" in name:
|
||||
return name.replace(".v_proj.output_scale", ".attn.v_scale")
|
||||
if name.endswith(".output_scale") and ".q_proj" in name:
|
||||
return name.replace(".q_proj.output_scale", ".attn.q_scale")
|
||||
if name.endswith("self_attn.prob_output_scale"):
|
||||
return name.replace(".prob_output_scale", ".attn.prob_scale")
|
||||
|
||||
# If no matches, return None
|
||||
return None
|
||||
|
||||
|
||||
class QuarkLinearMethod(LinearMethodBase):
|
||||
|
||||
def __init__(self, quantization_config: QuarkConfig):
|
||||
self.quantization_config = quantization_config
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.scheme.process_weights_after_loading(layer)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
"""
|
||||
Use the CompressedTensorsScheme associated with each layer to create
|
||||
the necessary parameters for the layer. See LinearMethodBase for param
|
||||
details
|
||||
"""
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
layer.scheme.create_weights(
|
||||
layer=layer,
|
||||
input_size=input_size,
|
||||
input_size_per_partition=input_size_per_partition,
|
||||
output_partition_sizes=output_partition_sizes,
|
||||
output_size=output_size,
|
||||
params_dtype=params_dtype,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
Use the output of create_weights and the CompressedTensorsScheme
|
||||
associated with the layer to apply the forward pass with the
|
||||
layer input. See LinearMethodBase for param details
|
||||
|
||||
"""
|
||||
scheme = layer.scheme
|
||||
if scheme is None:
|
||||
raise ValueError("A scheme must be defined for each layer")
|
||||
|
||||
return scheme.apply_weights(layer, x, bias=bias)
|
||||
|
||||
|
||||
class QuarkKVCacheMethod(BaseKVCacheMethod):
|
||||
"""
|
||||
Supports loading kv-cache scaling factors from quark checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: QuarkConfig):
|
||||
self.validate_kv_cache_config(quant_config.kv_cache_config)
|
||||
super().__init__(quant_config)
|
||||
|
||||
@staticmethod
|
||||
def validate_kv_cache_config(kv_cache_config: Optional[dict[str, Any]]):
|
||||
"""
|
||||
Validator for the kv cache configuration. Useful for controlling the
|
||||
kv cache quantization schemes, that are being supported in vLLM
|
||||
:param kv_cache_config: the quark kv cache scheme
|
||||
"""
|
||||
if kv_cache_config is None:
|
||||
return
|
||||
|
||||
dtype = kv_cache_config.get("dtype")
|
||||
if dtype != "fp8_e4m3":
|
||||
raise NotImplementedError(
|
||||
"Currently supported kv cache quantization is "
|
||||
f"dtype=fp8_e4m3, however received {dtype}")
|
||||
|
||||
qscheme = kv_cache_config.get("qscheme")
|
||||
if qscheme != "per_tensor":
|
||||
raise NotImplementedError(
|
||||
"Only support per-tensor scaling factor "
|
||||
"for quark KV cache. "
|
||||
f"Expected qscheme: per_tensor, found qscheme: {qscheme}")
|
||||
561
vllm/model_executor/layers/quantization/quark/quark_moe.py
Normal file
561
vllm/model_executor/layers/quantization/quark/quark_moe.py
Normal file
@@ -0,0 +1,561 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
|
||||
mxfp4_w4a4_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
prepare_moe_fp8_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
OCP_MX_BLOCK_SIZE)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = [
|
||||
"QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkW4A4MXFp4MoEMethod"
|
||||
]
|
||||
|
||||
|
||||
class QuarkMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(self, moe: FusedMoEConfig):
|
||||
super().__init__(moe)
|
||||
|
||||
@staticmethod
|
||||
def get_moe_method(
|
||||
quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821
|
||||
module: torch.nn.Module,
|
||||
layer_name: str) -> "QuarkMoEMethod":
|
||||
layer_quant_config = quant_config._find_matched_config(
|
||||
layer_name, module)
|
||||
|
||||
if (layer_quant_config.get("output_tensors")
|
||||
or layer_quant_config.get("bias")):
|
||||
raise NotImplementedError("Currently, Quark models with "
|
||||
"output_tensors and bias "
|
||||
"quantized are not supported")
|
||||
weight_config = layer_quant_config.get("weight")
|
||||
input_config = layer_quant_config.get("input_tensors")
|
||||
|
||||
if quant_config._is_fp8_w8a8(weight_config, input_config):
|
||||
return QuarkW8A8Fp8MoEMethod(weight_config, input_config,
|
||||
module.moe_config)
|
||||
elif quant_config._is_mx_fp4(weight_config, input_config):
|
||||
return QuarkW4A4MXFp4MoEMethod(weight_config, input_config,
|
||||
module.moe_config)
|
||||
else:
|
||||
raise RuntimeError("Unsupported FusedMoe scheme")
|
||||
|
||||
|
||||
class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_config: dict[str, Any],
|
||||
input_config: dict[str, Any],
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.weight_quant = weight_config
|
||||
self.input_quant = input_config
|
||||
|
||||
self.weight_qscheme = self.weight_quant.get("qscheme")
|
||||
self.input_qscheme = self.input_quant.get("qscheme")
|
||||
per_tensor = (self.weight_qscheme == "per_tensor"
|
||||
and self.input_qscheme == "per_tensor")
|
||||
per_channel = (self.weight_qscheme == "per_channel"
|
||||
and self.input_qscheme == "per_channel")
|
||||
self.act_quant_group_shape = GroupShape.PER_TOKEN \
|
||||
if per_channel else GroupShape.PER_TENSOR
|
||||
if not (per_tensor or per_channel):
|
||||
raise ValueError(
|
||||
"For FP8 Fused MoE layers, only per-tensor and per-channel "
|
||||
"scales for weights and activations are supported. Found "
|
||||
f"{self.weight_qscheme}, {self.input_qscheme}") # noqa E501
|
||||
|
||||
self.static_input_scales = not self.input_quant.get("is_dynamic")
|
||||
if self.static_input_scales and per_channel:
|
||||
raise ValueError(
|
||||
"For FP8 Fused MoE layer, we require either per tensor or "
|
||||
"channelwise, dynamic per token quantization.")
|
||||
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
self.use_marlin = (not current_platform.has_device_capability(89)
|
||||
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
|
||||
# Disable marlin for rocm
|
||||
if current_platform.is_rocm():
|
||||
self.use_marlin = False
|
||||
|
||||
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
|
||||
layer.intermediate_size_per_partition = intermediate_size_per_partition
|
||||
layer.hidden_size = hidden_size
|
||||
layer.num_experts = num_experts
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
params_dtype = torch.float8_e4m3fn
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# WEIGHT_SCALES
|
||||
if self.weight_qscheme == "per_tensor":
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
# They are combined to a single scale after weight loading.
|
||||
w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
num_experts, 2, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
w2_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
num_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
# Add PER-TENSOR quantization for FusedMoE.weight_loader.
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
elif self.weight_qscheme == "per_channel":
|
||||
# quark's scale is 1 dim.
|
||||
w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
w2_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
num_experts, hidden_size, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
# INPUT_SCALES
|
||||
if self.static_input_scales:
|
||||
w13_input_scale = torch.nn.Parameter(torch.ones(
|
||||
num_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
set_weight_attrs(w13_input_scale, extra_weight_attrs)
|
||||
|
||||
w2_input_scale = torch.nn.Parameter(torch.ones(
|
||||
num_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
||||
else:
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# Fp8 moe kernels require a single activation scale.
|
||||
# We take the max of all the scales in case they differ.
|
||||
if self.static_input_scales:
|
||||
if (layer.w13_input_scale is None or layer.w2_input_scale is None):
|
||||
raise ValueError(
|
||||
"QuantConfig has static quantization, but found "
|
||||
"activation scales are None.")
|
||||
if (not all_close_1d(layer.w13_input_scale)
|
||||
or not all_close_1d(layer.w2_input_scale)):
|
||||
logger.warning_once(
|
||||
"Found input_scales that are not equal for "
|
||||
"fp8 MoE layer. Using the maximum across experts "
|
||||
"for each layer. ")
|
||||
layer.w13_input_scale = torch.nn.Parameter(
|
||||
layer.w13_input_scale.max(), requires_grad=False)
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
layer.w2_input_scale.max(), requires_grad=False)
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
# Normalize the weights and scales
|
||||
w13_weight, w13_weight_scale, w13_input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
layer.w13_weight, layer.w13_weight_scale,
|
||||
layer.w13_input_scale)
|
||||
w2_weight, w2_weight_scale, w2_input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
layer.w2_weight, layer.w2_weight_scale,
|
||||
layer.w2_input_scale)
|
||||
# Reset the parameter
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight,
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale,
|
||||
requires_grad=False)
|
||||
if w13_input_scale is not None:
|
||||
layer.w13_input_scale = torch.nn.Parameter(w13_input_scale,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight,
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
|
||||
requires_grad=False)
|
||||
if w2_input_scale is not None:
|
||||
layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
|
||||
requires_grad=False)
|
||||
|
||||
# For per-tensor case, Fp8 moe kernel needs single weight scale
|
||||
# for w13 per expert. Use max then dequant and requant each expert.
|
||||
if self.weight_qscheme == "per_tensor":
|
||||
assert layer.w13_weight_scale is not None
|
||||
shard_size = layer.intermediate_size_per_partition
|
||||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||
for expert_id in range(layer.local_num_experts):
|
||||
start = 0
|
||||
for shard_id in range(2):
|
||||
dq_weight = per_tensor_dequantize(
|
||||
layer.w13_weight[expert_id][start:start +
|
||||
shard_size, :],
|
||||
layer.w13_weight_scale[expert_id][shard_id])
|
||||
layer.w13_weight[expert_id][
|
||||
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
|
||||
dq_weight, max_w13_scales[expert_id])
|
||||
start += shard_size
|
||||
|
||||
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
||||
requires_grad=False)
|
||||
# quark's scale is 1 dim.
|
||||
elif self.weight_qscheme == "per_channel":
|
||||
if self.act_quant_group_shape == GroupShape.PER_TOKEN:
|
||||
w13_weight_scale = layer.w13_weight_scale.unsqueeze(-1)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
w13_weight_scale, requires_grad=False)
|
||||
w2_weight_scale = layer.w2_weight_scale.unsqueeze(-1)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
|
||||
requires_grad=False)
|
||||
# Property to determine if AITER is used
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
|
||||
rocm_aiter_fused_experts, shuffle_weights)
|
||||
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
|
||||
requires_grad=False)
|
||||
|
||||
self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
|
||||
elif self.use_marlin:
|
||||
|
||||
prepare_moe_fp8_layer_for_marlin(layer, False)
|
||||
# Activations not quantized for marlin.
|
||||
del layer.w13_input_scale
|
||||
del layer.w2_input_scale
|
||||
self.fused_experts_func = None
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
self.fused_experts_func = fused_experts
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
per_act_token_quant=self.weight_qscheme == "per_channel",
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.")
|
||||
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
return self.rocm_aiter_fused_experts_func(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
quant_config=self.moe_quant_config,
|
||||
expert_map=expert_map)
|
||||
if self.use_marlin:
|
||||
assert activation == "silu", (
|
||||
f"{activation} not supported for Marlin MoE.")
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
None,
|
||||
None,
|
||||
layer.w13_weight_scale,
|
||||
layer.w2_weight_scale,
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type_id=scalar_types.float8_e4m3fn.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
|
||||
assert self.fused_experts_func is not None
|
||||
|
||||
return self.fused_experts_func(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
quant_config=self.moe_quant_config)
|
||||
|
||||
|
||||
class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_config: dict[str, Any],
|
||||
input_config: dict[str, Any],
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.weight_quant = weight_config
|
||||
self.input_quant = input_config
|
||||
|
||||
weight_qscheme = self.weight_quant.get("qscheme")
|
||||
input_qscheme = self.input_quant.get("qscheme")
|
||||
if not (weight_qscheme == "per_group"
|
||||
and input_qscheme == "per_group"):
|
||||
raise ValueError(
|
||||
"For MX(FP4) Fused MoE layers, only per-group scales "
|
||||
"for weights and activations are supported. Found "
|
||||
f"{weight_qscheme}, {input_qscheme}") # noqa E501
|
||||
|
||||
self.static_input_scales = not self.input_quant.get("is_dynamic")
|
||||
|
||||
if self.static_input_scales:
|
||||
raise NotImplementedError(
|
||||
"QuarkW4A4MXFp4MoEMethod with static input scales is currently "
|
||||
"not implemented. Please open an issue.")
|
||||
|
||||
if not current_platform.supports_mx():
|
||||
self.emulate = True
|
||||
logger.warning_once(
|
||||
"The current platform does not support native MXFP4 "
|
||||
"computation. Simulated weight dequantization and activation "
|
||||
"QDQ (quantize and dequantize) will be used, with the linear "
|
||||
"layers computed in high precision.")
|
||||
else:
|
||||
self.emulate = True
|
||||
logger.warning_once(
|
||||
"The current platform supports native MXFP4 "
|
||||
"computation, but kernels are not yet integrated in vLLM. "
|
||||
"Simulated weight dequantization and activation "
|
||||
"QDQ (quantize and dequantize) will be used, with the linear "
|
||||
"layers computed in high precision.")
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
|
||||
# Add the quantization method used (per tensor/grouped/channel)
|
||||
# to ensure the weight scales are loaded in properly
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})
|
||||
|
||||
params_dtype = torch.uint8
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // 2,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition // 2,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# WEIGHT_SCALES
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // OCP_MX_BLOCK_SIZE,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition // OCP_MX_BLOCK_SIZE,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
return mxfp4_w4a4_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=None,
|
||||
a2_scale=None,
|
||||
block_shape=None,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `QuarkW4A4MXFp4MoEMethod` yet.")
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
out = fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
return out
|
||||
@@ -0,0 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .quark_scheme import QuarkScheme
|
||||
from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4
|
||||
from .quark_w8a8_fp8 import QuarkW8A8Fp8
|
||||
from .quark_w8a8_int8 import QuarkW8A8Int8
|
||||
|
||||
__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkW4A4MXFP4"]
|
||||
@@ -0,0 +1,55 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ["QuarkScheme"]
|
||||
|
||||
|
||||
class QuarkScheme(ABC):
|
||||
"""
|
||||
Abstract class used to describe the weight creation and forward pass
|
||||
of different quantization schemes supported by Quark.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
"""
|
||||
Get minimum device capability.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(self, *args, **kwargs):
|
||||
"""
|
||||
Weight creation for the particular scheme. Inputs to this function
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]):
|
||||
"""
|
||||
Run the forward pass for the particular scheme. This is where
|
||||
scheme-specific dequant/quant steps/kernels should be applied.
|
||||
|
||||
:param layer: torch.nn.Module with the registered weights and
|
||||
other parameters relevant to the particular scheme.
|
||||
:param x: input to the layer
|
||||
:param bias: bias parameter
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
"""
|
||||
Called after weight loading is complete for any cleanup that
|
||||
needs to occur.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,239 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from functools import cache
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import envs
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
OCP_MX_BLOCK_SIZE, dequant_mxfp4, quant_dequant_mxfp4)
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@cache
|
||||
def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
|
||||
return current_platform.is_rocm() \
|
||||
and envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM \
|
||||
and envs.VLLM_ROCM_USE_AITER
|
||||
|
||||
|
||||
try:
|
||||
from aiter.ops.shuffle import shuffle_weight
|
||||
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
|
||||
from aiter.ops.triton.quant import dynamic_mxfp4_quant
|
||||
|
||||
from vllm.utils import direct_register_custom_op
|
||||
if is_rocm_aiter_fp4_asm_gemm_enabled():
|
||||
from aiter import gemm_a4w4, per_1x32_f4_quant_hip
|
||||
|
||||
def gemm_with_dynamic_quant(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
rocm_use_aiter_fp4_asm_gemm: bool = False,
|
||||
out_dtype: Optional[torch.dtype] = torch.bfloat16,
|
||||
x_scales: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
M = x.shape[0]
|
||||
if rocm_use_aiter_fp4_asm_gemm:
|
||||
if x_scales is None:
|
||||
# use hip quant kernel for performance
|
||||
x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
|
||||
else:
|
||||
x_q = x
|
||||
x_s = x_scales
|
||||
|
||||
# 32 alignment is enough for dim0 padding of output for
|
||||
# gemm_a4w4 kernel
|
||||
y = torch.empty((M + 31) // 32 * 32,
|
||||
weight.shape[0],
|
||||
device=x_q.device,
|
||||
dtype=out_dtype)
|
||||
|
||||
gemm_a4w4(x_q,
|
||||
weight,
|
||||
x_s,
|
||||
weight_scale.view(x_s.dtype),
|
||||
y,
|
||||
bpreshuffle=True)
|
||||
return y[:M]
|
||||
else:
|
||||
if x_scales is None:
|
||||
x_q, x_s = dynamic_mxfp4_quant(x)
|
||||
else:
|
||||
x_q = x
|
||||
x_s = x_scales
|
||||
y = torch.empty(x_q.shape[0],
|
||||
weight.shape[0],
|
||||
device=x_q.device,
|
||||
dtype=out_dtype)
|
||||
|
||||
gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
|
||||
return y
|
||||
|
||||
def gemm_with_dynamic_quant_fake(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
x_scales: torch.Tensor = None,
|
||||
rocm_use_aiter_fp4_asm_gemm: bool = False,
|
||||
out_dtype: Optional[torch.dtype] = torch.bfloat16,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty((*x.shape[:-1], weight.shape[0]),
|
||||
dtype=out_dtype,
|
||||
device=x.device)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="gemm_with_dynamic_quant",
|
||||
op_func=gemm_with_dynamic_quant,
|
||||
mutates_args=[],
|
||||
fake_impl=gemm_with_dynamic_quant_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
dynamic_mxfp4_quant = gemm_afp4wfp4 = None
|
||||
|
||||
__all__ = ["QuarkW4A4MXFP4"]
|
||||
|
||||
|
||||
class QuarkW4A4MXFP4(QuarkScheme):
|
||||
|
||||
def __init__(self, weight_quant_spec: dict[str, Any],
|
||||
input_quant_spec: dict[str, Any]):
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.qscheme = "per_group"
|
||||
self.weight_quant_spec = weight_quant_spec
|
||||
self.input_quant_spec = input_quant_spec
|
||||
self.emulate = not current_platform.supports_mx()
|
||||
self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled()
|
||||
if not self.emulate and (dynamic_mxfp4_quant is None
|
||||
or gemm_afp4wfp4 is None):
|
||||
# Currently need these kernels if not emulating
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} requires AITER to be installed "
|
||||
"for non-emulation mode! Please refer to "
|
||||
"https://github.com/ROCm/aiter for installation details.")
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.weight = torch.nn.Parameter(layer.weight.data,
|
||||
requires_grad=False)
|
||||
|
||||
if self.emulate:
|
||||
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
|
||||
requires_grad=False)
|
||||
try:
|
||||
from quark.torch.export.nn.modules import realquantizer
|
||||
from quark.torch.quantization.config.config import (
|
||||
QuantizationSpec)
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"The package `amd-quark` is required to use AMD Quark "
|
||||
"MX-FP4 models. Please install it with `pip install "
|
||||
"amd-quark`.") from err
|
||||
|
||||
weight_quant_spec = QuantizationSpec.from_dict(
|
||||
self.weight_quant_spec)
|
||||
|
||||
weight_quantizer = realquantizer.get_real_quantizer(
|
||||
qspec=weight_quant_spec,
|
||||
quantizer=None,
|
||||
real_quantized=True,
|
||||
reorder=False,
|
||||
float_dtype=self.out_dtype,
|
||||
scale_shape=layer.weight_scale.shape,
|
||||
zero_point_shape=None,
|
||||
)
|
||||
weight_quantizer.scale.data = layer.weight_scale.data
|
||||
|
||||
layer.weight = torch.nn.Parameter(
|
||||
weight_quantizer(layer.weight.data).to(self.out_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.weight_scale = None
|
||||
|
||||
# This call is necessary to release the scales memory.
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
if self.rocm_use_aiter_fp4_asm_gemm:
|
||||
# shuffle weight scale
|
||||
weight_scale_shuffle = layer.weight_scale.data
|
||||
sm, sn = weight_scale_shuffle.shape
|
||||
weight_scale_shuffle = weight_scale_shuffle.view(
|
||||
sm // 32, 2, 16, sn // 8, 2, 4, 1)
|
||||
weight_scale_shuffle = weight_scale_shuffle.permute(
|
||||
0, 3, 5, 2, 4, 1, 6).contiguous()
|
||||
weight_scale_shuffle = weight_scale_shuffle.view(sm, sn)
|
||||
layer.weight_scale = torch.nn.Parameter(weight_scale_shuffle,
|
||||
requires_grad=False)
|
||||
|
||||
# shuffle weight
|
||||
weight_shuffle = layer.weight.data
|
||||
weight_shuffle = shuffle_weight(weight_shuffle,
|
||||
layout=(16, 16))
|
||||
layer.weight = torch.nn.Parameter(weight_shuffle,
|
||||
requires_grad=False)
|
||||
else:
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
layer.weight_scale.data.T.contiguous(),
|
||||
requires_grad=False)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
# WEIGHT
|
||||
weight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // 2,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
packed_dim=1,
|
||||
packed_factor=2,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
weight_scale = GroupQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // OCP_MX_BLOCK_SIZE,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
if self.emulate:
|
||||
dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype)
|
||||
x = quant_dequant_mxfp4(x)
|
||||
return F.linear(x, dq_w, bias)
|
||||
else:
|
||||
return torch.ops.vllm.gemm_with_dynamic_quant(
|
||||
x, layer.weight, layer.weight_scale,
|
||||
self.rocm_use_aiter_fp4_asm_gemm, self.out_dtype)
|
||||
@@ -0,0 +1,163 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional, cast
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
__all__ = ["QuarkW8A8Fp8"]
|
||||
|
||||
|
||||
class QuarkW8A8Fp8(QuarkScheme):
|
||||
|
||||
def __init__(self, weight_config: dict[str, Any],
|
||||
input_config: Optional[dict[str, Any]]):
|
||||
self.weight_qscheme = cast(str, weight_config.get("qscheme"))
|
||||
self.is_static_input_scheme: bool = False
|
||||
self.input_qscheme: Optional[str] = None
|
||||
if input_config is not None:
|
||||
self.is_static_input_scheme = not cast(
|
||||
bool, input_config.get("is_dynamic"))
|
||||
self.input_qscheme = cast(str, input_config.get("qscheme"))
|
||||
|
||||
per_token = (not self.is_static_input_scheme
|
||||
and self.input_qscheme == "per_channel")
|
||||
self.act_quant_group_shape = GroupShape.PER_TOKEN \
|
||||
if per_token else GroupShape.PER_TENSOR
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=self.is_static_input_scheme,
|
||||
act_quant_group_shape=self.act_quant_group_shape)
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# lovelace and up
|
||||
return 89
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
# If per tensor, when we have a fused module (e.g. QKV) with per
|
||||
# tensor scales (thus N scales being passed to the kernel),
|
||||
# requantize so we can always run per tensor
|
||||
if self.weight_qscheme == "per_tensor":
|
||||
if current_platform.is_fp8_fnuz():
|
||||
input_scale = getattr(layer, 'input_scale', None)
|
||||
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=input_scale)
|
||||
if input_scale is not None:
|
||||
layer.input_scale = Parameter(input_scale,
|
||||
requires_grad=False)
|
||||
else:
|
||||
max_w_scale = layer.weight_scale
|
||||
weight = layer.weight
|
||||
|
||||
max_w_scale, weight = requantize_with_max_scale(
|
||||
weight=weight,
|
||||
weight_scale=max_w_scale,
|
||||
logical_widths=layer.logical_widths,
|
||||
)
|
||||
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
||||
|
||||
# If channelwise, scales are already lined up, so just transpose.
|
||||
elif self.weight_qscheme == "per_channel":
|
||||
weight = layer.weight
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
input_scale = getattr(layer, 'input_scale', None)
|
||||
weight, weight_scale, input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=input_scale)
|
||||
if input_scale is not None:
|
||||
layer.input_scale = Parameter(input_scale,
|
||||
requires_grad=False)
|
||||
else:
|
||||
weight_scale = layer.weight_scale.data
|
||||
if self.act_quant_group_shape == GroupShape.PER_TOKEN:
|
||||
weight_scale = weight_scale.view(-1, 1)
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown quantization scheme {self.weight_qscheme}")
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
layer.input_scale = Parameter(layer.input_scale.max(),
|
||||
requires_grad=False)
|
||||
else:
|
||||
layer.input_scale = None
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=torch.float8_e4m3fn),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
# TODO: update create_xxx_parameter functions to return
|
||||
# the newly added parameters
|
||||
if self.weight_qscheme == "per_channel":
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes)),
|
||||
dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
else:
|
||||
assert self.weight_qscheme == "per_tensor"
|
||||
weight_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
|
||||
# min requirement for fp8 kernels
|
||||
weight_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
input_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
return self.fp8_linear.apply(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias)
|
||||
@@ -0,0 +1,122 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel)
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class QuarkW8A8Int8(QuarkScheme):
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool],
|
||||
input_symmetric: Optional[bool]):
|
||||
self.qscheme = qscheme
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.input_symmetric = input_symmetric
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# turing and up
|
||||
return 75
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
|
||||
is_channelwise=(self.qscheme == "per_channel"),
|
||||
is_static_input_scheme=(self.is_static_input_scheme is True),
|
||||
input_symmetric=(self.input_symmetric is True))
|
||||
|
||||
kernel_type = choose_scaled_mm_linear_kernel(
|
||||
scaled_mm_linear_kernel_config)
|
||||
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=torch.int8),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
if self.qscheme == "per_channel":
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes)),
|
||||
dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
ChannelQuantZPParameter = ChannelQuantScaleParameter
|
||||
weight_zero_point = ChannelQuantZPParameter(
|
||||
data=torch.empty((sum(output_partition_sizes)),
|
||||
dtype=torch.int8),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
else:
|
||||
assert self.qscheme == "per_tensor"
|
||||
weight_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
PerTensorZPParameter = PerTensorScaleParameter
|
||||
weight_zero_point = PerTensorZPParameter(
|
||||
data=torch.empty(len(output_partition_sizes),
|
||||
dtype=torch.int8),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
layer.register_parameter("weight_zero_point", weight_zero_point)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = BasevLLMParameter(data=torch.empty(
|
||||
1, dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
input_zero_point = BasevLLMParameter(data=torch.empty(
|
||||
1, dtype=torch.int8),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
|
||||
self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
|
||||
w_q_param_name="weight",
|
||||
w_s_param_name="weight_scale",
|
||||
i_s_param_name="input_scale",
|
||||
i_zp_param_name="input_zero_point",
|
||||
azp_adj_param_name="azp_adj")
|
||||
|
||||
# Checkpoints are serialized in quark format, which is
|
||||
# different from the format the kernel may want. Handle repacking here.
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.register_parameter("weight_zero_point", None)
|
||||
delattr(layer, 'weight_zero_point')
|
||||
if self.input_symmetric:
|
||||
layer.register_parameter("input_zero_point", None)
|
||||
delattr(layer, 'input_zero_point')
|
||||
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
105
vllm/model_executor/layers/quantization/quark/utils.py
Normal file
105
vllm/model_executor/layers/quantization/quark/utils.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable, Mapping
|
||||
from types import MappingProxyType
|
||||
from typing import Any, Optional
|
||||
|
||||
import regex as re
|
||||
|
||||
|
||||
def deep_compare(dict1: Any, dict2: Any) -> bool:
|
||||
if type(dict1) is not type(dict2):
|
||||
return False
|
||||
if isinstance(dict1, dict):
|
||||
if dict1.keys() != dict2.keys():
|
||||
return False
|
||||
return all(deep_compare(dict1[k], dict2[k]) for k in dict1)
|
||||
elif isinstance(dict1, list):
|
||||
return set(dict1) == set(dict2)
|
||||
else:
|
||||
return dict1 == dict2
|
||||
|
||||
|
||||
def should_ignore_layer(
|
||||
layer_name: Optional[str],
|
||||
ignore: Iterable[str],
|
||||
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
|
||||
) -> bool:
|
||||
if layer_name is None:
|
||||
return False
|
||||
|
||||
# layer_name = model.layers.0.self_attn.qkv_proj
|
||||
# proj_name = qkv_proj
|
||||
proj_name = layer_name.split(".")[-1]
|
||||
|
||||
# Fused layers like gate_up_proj or qkv_proj will not be fused
|
||||
# in the safetensors checkpoint. So, we convert the name
|
||||
# from the fused version to unfused + check to make sure that
|
||||
# each shard of the fused layer has the same scheme.
|
||||
if proj_name in fused_mapping:
|
||||
shard_proj_names = fused_mapping[proj_name]
|
||||
|
||||
# Convert fused_name --> [shard_names]
|
||||
shard_names = [
|
||||
layer_name.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in shard_proj_names
|
||||
]
|
||||
|
||||
# Layer should be ignored if shards are ignored.
|
||||
should_ignore_layer = None
|
||||
for shard_name in shard_names:
|
||||
should_ignore_shard = check_equal_or_regex_match(
|
||||
layer_name=shard_name, targets=ignore)
|
||||
|
||||
# If shard_idx=0, set layer ignore to match shard.
|
||||
if should_ignore_layer is None:
|
||||
should_ignore_layer = should_ignore_shard
|
||||
|
||||
# If shard_idx=1+ confirm scheme matches prior shards.
|
||||
elif should_ignore_shard != should_ignore_layer:
|
||||
raise ValueError(f"Found a different quantization schemes for "
|
||||
f"{shard_proj_names} in {layer_name}. vLLM "
|
||||
"requires all to use the same scheme.")
|
||||
|
||||
# Unfused layers like down_proj and o_proj will match
|
||||
# the safetensors checkpoint already.
|
||||
else:
|
||||
should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name,
|
||||
targets=ignore)
|
||||
|
||||
assert should_ignore_layer is not None
|
||||
return should_ignore_layer
|
||||
|
||||
|
||||
def check_equal_or_regex_match(layer_name: str,
|
||||
targets: Iterable[str]) -> bool:
|
||||
"""
|
||||
Checks whether a layer_name is exactly equal or a regex match for
|
||||
if target starts with 're:' to any target in list.
|
||||
"""
|
||||
for target in targets:
|
||||
if _is_equal_or_regex_match(layer_name, target):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_equal_or_regex_match(value: str,
|
||||
target: str,
|
||||
check_contains: bool = False) -> bool:
|
||||
"""
|
||||
Checks whether a value is exactly equal or a regex match for target
|
||||
if target starts with 're:'. If check_contains is set to True,
|
||||
additionally checks if the target string is contained within the value.
|
||||
"""
|
||||
|
||||
if target.startswith("re:"):
|
||||
pattern = target[3:]
|
||||
if re.match(pattern, value):
|
||||
return True
|
||||
elif check_contains:
|
||||
if target.lower() in value.lower():
|
||||
return True
|
||||
elif target == value:
|
||||
return True
|
||||
return False
|
||||
466
vllm/model_executor/layers/quantization/rtn.py
Normal file
466
vllm/model_executor/layers/quantization/rtn.py
Normal file
@@ -0,0 +1,466 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright © 2025, Oracle and/or its affiliates.
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, int4_w4a16_moe_quant_config,
|
||||
int8_w8a16_moe_quant_config)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
"""By default, use 8 bit as target precision, but it can be
|
||||
overridden by setting the RTN_NUM_BITS envvar
|
||||
"""
|
||||
NUM_BITS = os.getenv('RTN_NUM_BITS', "8")
|
||||
"""By default, use group size of 128 parameters, but it can be
|
||||
overridden by setting the RTN_GROUP_SIZE envvar
|
||||
"""
|
||||
GROUP_SIZE = os.getenv('RTN_GROUP_SIZE', "128")
|
||||
|
||||
|
||||
class RTNConfig(QuantizationConfig):
|
||||
"""Config class for RTN.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int = int(NUM_BITS),
|
||||
group_size: int = int(GROUP_SIZE),
|
||||
) -> None:
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
|
||||
if self.weight_bits != 4 and self.weight_bits != 8:
|
||||
raise ValueError(
|
||||
"Currently, only 4-bit or 8-bit weight quantization is "
|
||||
f"supported for RTN, but got {self.weight_bits} bits.")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"RTNConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "rtn"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "RTNConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
return cls(weight_bits, group_size)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return RTNLinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return RTNMoEMethod(self, layer.moe_config)
|
||||
return None
|
||||
|
||||
|
||||
class RTNTensor:
|
||||
"""A wrapper over Tensor that enables quantization on-the-fly by
|
||||
overloading the copy_ method.
|
||||
"""
|
||||
|
||||
def __init__(self, data: torch.Tensor, scale: torch.Tensor,
|
||||
quant_config: RTNConfig) -> None:
|
||||
self.data = data
|
||||
self.scale = scale
|
||||
self.quant_config = quant_config
|
||||
|
||||
def narrow(self, dim, start, length):
|
||||
factor = 1 if self.quant_config.weight_bits == 8 else 2
|
||||
return RTNTensor(
|
||||
self.data.narrow(dim, start // factor, length // factor),
|
||||
self.scale.narrow(dim, start, length), self.quant_config)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return RTNTensor(self.data[key], self.scale[key], self.quant_config)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
shape = self.data.shape
|
||||
factor = 1 if self.quant_config.weight_bits == 8 else 2
|
||||
batch_present = len(shape) == 3
|
||||
if batch_present:
|
||||
return torch.Size((shape[0], shape[1] * factor, shape[2]))
|
||||
else:
|
||||
return torch.Size((shape[0] * factor, shape[1]))
|
||||
|
||||
def copy_(self, loaded_weight: torch.Tensor) -> None:
|
||||
qweight, weight_scale = rtn_quantize(loaded_weight.cuda(),
|
||||
self.quant_config.weight_bits,
|
||||
self.quant_config.group_size)
|
||||
|
||||
self.data.copy_(qweight)
|
||||
self.scale.data.copy_(weight_scale)
|
||||
|
||||
|
||||
class RTNParameter(Parameter):
|
||||
"""A wrapper over Parameter that returns RTNTensor (a wrapper over Tensor)
|
||||
when its data is accessed. We need this wrapper for the data loading phase
|
||||
only, so we can intercept a weight copying function (torch.Tensor.copy_)
|
||||
and apply quantization on-the-fly.
|
||||
"""
|
||||
|
||||
def __new__(cls, data: torch.Tensor, **kwargs):
|
||||
return super().__new__(cls, data=data, requires_grad=False)
|
||||
|
||||
def __init__(self, data: torch.Tensor, scale: torch.Tensor,
|
||||
quant_config: RTNConfig) -> None:
|
||||
self.scale = scale
|
||||
self.quant_config = quant_config
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return RTNTensor(super().data, self.scale, self.quant_config)
|
||||
|
||||
|
||||
class RTNLinearMethod(LinearMethodBase):
|
||||
"""Linear method for RTN.
|
||||
|
||||
Args:
|
||||
quant_config: The RTN quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: RTNConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
num_groups_per_col = (input_size_per_partition //
|
||||
self.quant_config.group_size
|
||||
if self.quant_config.group_size != -1 else 1)
|
||||
|
||||
scale = Parameter(
|
||||
torch.empty(output_size_per_partition,
|
||||
num_groups_per_col,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
factor = 1 if self.quant_config.weight_bits == 8 else 2
|
||||
|
||||
weight = RTNParameter(data=torch.empty(output_size_per_partition //
|
||||
factor,
|
||||
input_size_per_partition,
|
||||
dtype=torch.uint8),
|
||||
scale=scale,
|
||||
quant_config=self.quant_config)
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, {
|
||||
**extra_weight_attrs,
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
})
|
||||
|
||||
layer.register_parameter("scale", scale)
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
fix_weights(layer, "weight")
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
qweight = layer.weight
|
||||
scale = layer.scale
|
||||
|
||||
weight = rtn_dequantize(qweight, scale)
|
||||
out = F.linear(x, weight)
|
||||
del weight
|
||||
if bias is not None:
|
||||
out.add_(bias)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class RTNMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(self, quant_config: RTNConfig, moe: FusedMoEConfig):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
|
||||
factor = 1 if self.quant_config.weight_bits == 8 else 2
|
||||
|
||||
# Fused gate_up_proj (column parallel)
|
||||
num_groups_per_col = (hidden_size // self.quant_config.group_size
|
||||
if self.quant_config.group_size != -1 else 1)
|
||||
w13_scale = Parameter(
|
||||
torch.empty(num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
num_groups_per_col,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_scale", w13_scale)
|
||||
|
||||
w13_weight = RTNParameter(data=torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition // factor,
|
||||
hidden_size,
|
||||
dtype=torch.uint8),
|
||||
scale=w13_scale,
|
||||
quant_config=self.quant_config)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
num_groups_per_col = (intermediate_size_per_partition //
|
||||
self.quant_config.group_size
|
||||
if self.quant_config.group_size != -1 else 1)
|
||||
w2_scale = Parameter(torch.zeros(num_experts,
|
||||
hidden_size,
|
||||
num_groups_per_col,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_scale", w2_scale)
|
||||
|
||||
w2_weight = RTNParameter(data=torch.empty(
|
||||
num_experts,
|
||||
hidden_size // factor,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.uint8),
|
||||
scale=w2_scale,
|
||||
quant_config=self.quant_config)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
weight_bits = self.quant_config.weight_bits
|
||||
fix_weights(layer, "w13_weight", weight_bits == 4)
|
||||
fix_weights(layer, "w2_weight", weight_bits == 4)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
weight_bits = self.quant_config.weight_bits
|
||||
group_size = self.quant_config.group_size
|
||||
assert weight_bits == 4 or weight_bits == 8
|
||||
config_builder = (int4_w4a16_moe_quant_config
|
||||
if weight_bits == 4 else int8_w8a16_moe_quant_config)
|
||||
return config_builder(
|
||||
w1_scale=layer.w13_scale,
|
||||
w2_scale=layer.w2_scale,
|
||||
w1_zp=None,
|
||||
w2_zp=None,
|
||||
block_shape=[0, group_size],
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `RTNMoEMethod` yet.")
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
|
||||
def rtn_quantize(tensor: torch.Tensor, num_bits: int,
|
||||
group_size: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Quantize a tensor using per-group static scaling factor.
|
||||
|
||||
Args:
|
||||
tensor: The input tensor.
|
||||
num_bits: Target precision for the result (supported values are
|
||||
8 or 4).
|
||||
group_size: Quantization granularity.
|
||||
If equal to -1, each row in the input tensor is treated
|
||||
as one group.
|
||||
"""
|
||||
batch_present = len(tensor.shape) == 3
|
||||
if not batch_present:
|
||||
tensor = tensor.unsqueeze(0)
|
||||
|
||||
q_range = 2**num_bits
|
||||
num_groups = (tensor.shape[1] * tensor.shape[2] //
|
||||
group_size if group_size != -1 else tensor.shape[1])
|
||||
"""Calculate a scaling factor per input group.
|
||||
"""
|
||||
input_flat = tensor.reshape(tensor.shape[0], num_groups, -1)
|
||||
input_min = torch.min(input_flat, dim=2, keepdim=True)[0]
|
||||
input_max = torch.max(input_flat, dim=2, keepdim=True)[0]
|
||||
input_max_abs = torch.max(input_min.abs(), input_max.abs())
|
||||
scale = (input_max_abs * 2.0 / (q_range - 1))
|
||||
"""Scale each input group, round to the nearest integer, shift
|
||||
the range and truncate.
|
||||
"""
|
||||
scaled_input = input_flat / scale
|
||||
scaled_input = scaled_input.round()
|
||||
scaled_input += q_range // 2
|
||||
scaled_input = scaled_input.clamp(0, q_range - 1)
|
||||
|
||||
scale = scale.reshape(tensor.shape[0], tensor.shape[1], -1).contiguous()
|
||||
inputs_q = scaled_input.reshape(tensor.shape).to(torch.uint8)
|
||||
inputs_q = inputs_q.contiguous()
|
||||
|
||||
if num_bits == 4:
|
||||
"""Pack two 4-bit values into each byte.
|
||||
"""
|
||||
inputs_q = (inputs_q[:, :, 1::2] << 4) | (inputs_q[:, :, ::2] & 0xf)
|
||||
inputs_q = inputs_q.reshape(tensor.shape[0], tensor.shape[1] // 2,
|
||||
tensor.shape[2])
|
||||
inputs_q = inputs_q.contiguous()
|
||||
|
||||
if not batch_present:
|
||||
inputs_q = inputs_q.squeeze(0)
|
||||
scale = scale.squeeze(0)
|
||||
|
||||
return inputs_q, scale
|
||||
|
||||
|
||||
def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
||||
"""Dequantize a tensor using per-group static scaling factors.
|
||||
|
||||
Args:
|
||||
tensor: The input tensor.
|
||||
scale: The tensor with per-group scale factors.
|
||||
"""
|
||||
batch_present = len(tensor.shape) == 3
|
||||
if not batch_present:
|
||||
tensor = tensor.unsqueeze(0)
|
||||
scale = scale.unsqueeze(0)
|
||||
|
||||
num_groups = scale.size(1) * scale.size(2)
|
||||
batch, input_dim, output_dim = tensor.shape
|
||||
|
||||
num_bits = 8 if input_dim == scale.size(1) else 4
|
||||
q_range = 2**num_bits
|
||||
if num_bits == 4:
|
||||
input_dim *= 2
|
||||
|
||||
data = torch.empty((batch, input_dim, output_dim),
|
||||
dtype=scale.dtype,
|
||||
device=tensor.device)
|
||||
|
||||
if num_bits == 8:
|
||||
data.copy_(tensor)
|
||||
data -= q_range // 2
|
||||
else:
|
||||
"""Unpack two 4-bit values from each byte.
|
||||
"""
|
||||
tensor = tensor.reshape(batch, input_dim, output_dim // 2)
|
||||
for i in range(2):
|
||||
data[:, :, i::2] = ((tensor << 4 *
|
||||
(1 - i)) >> 4).to(torch.int8) - q_range // 2
|
||||
"""Scale each input group with its scaling factor.
|
||||
"""
|
||||
scale = scale.reshape(batch, num_groups, -1)
|
||||
data = data.reshape(batch, num_groups, -1)
|
||||
data = torch.mul(data, scale)
|
||||
|
||||
input_deq = data.reshape((batch, input_dim, output_dim)).contiguous()
|
||||
if not batch_present:
|
||||
input_deq = input_deq.squeeze(0)
|
||||
|
||||
return input_deq
|
||||
|
||||
|
||||
def fix_weights(layer: torch.nn.Module,
|
||||
param_name: str,
|
||||
reshape: bool = False):
|
||||
"""torch.compile does not know how to deal with a Parameter subclass
|
||||
(aka RTNParameter). As we don't really need RTNParameters for the
|
||||
forward pass, we replace them with equivalent instances of Parameters.
|
||||
"""
|
||||
old_weight = getattr(layer, param_name)
|
||||
assert isinstance(old_weight, RTNParameter)
|
||||
data = old_weight.data.data
|
||||
|
||||
delattr(layer, param_name)
|
||||
|
||||
if reshape:
|
||||
data = data.reshape(old_weight.shape[0], old_weight.shape[1] * 2, -1)
|
||||
new_weight = Parameter(data=data, requires_grad=False)
|
||||
layer.register_parameter(param_name, new_weight)
|
||||
86
vllm/model_executor/layers/quantization/schema.py
Normal file
86
vllm/model_executor/layers/quantization/schema.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file contains the Pydantic schemas for various quantization-related
|
||||
parameters. When a relevant quantization technique is specified, these
|
||||
parameters are loaded in the form of a JSON alongside the model weights
|
||||
and augment the model with additional information needed for use of that
|
||||
technique. The format of this JSON should be specified by one or more
|
||||
schemas contained here.
|
||||
|
||||
For example, when the KV cache is quantized to FP8-E4M3 (currently only
|
||||
possible on ROCm), the model can be optionally augmented with KV cache
|
||||
scaling factors.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
|
||||
|
||||
|
||||
class KVCacheQuantSchema(BaseModel):
|
||||
dtype: str
|
||||
# Each key is a TP rank. Each value is a dictionary mapping a TP rank's
|
||||
# layer indices to their per-tensor KV cache scaling factor.
|
||||
# TODO: Consider pulling this and its validation methods out into its
|
||||
# own schema class (tricky as its members are variable)
|
||||
scaling_factor: dict[int, dict[int, float]]
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_is_fp8(self) -> "KVCacheQuantSchema":
|
||||
assert self.dtype == "float8_e4m3fn", (
|
||||
"Loaded scaling factors intended for KV cache dtype = "
|
||||
f"{self.dtype} rather than float8_e4m3fn!")
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
|
||||
context = info.context
|
||||
if context:
|
||||
tp_size = context["tp_size"]
|
||||
num_hidden_layers = context["num_hidden_layers"]
|
||||
assert len(self.scaling_factor) == tp_size, (
|
||||
f"Loaded dictionary has TP size {len(self.scaling_factor)} "
|
||||
f"but LLM engine is currently running with TP size {tp_size}.")
|
||||
for tp_rank, layer_maps in self.scaling_factor.items():
|
||||
assert len(layer_maps) == num_hidden_layers, (
|
||||
f"KV cache scales map for TP rank {tp_rank} is malformed. "
|
||||
f"Expected {num_hidden_layers} layers, got "
|
||||
f"{len(layer_maps)}.")
|
||||
for i in range(tp_size):
|
||||
assert i in self.scaling_factor, (
|
||||
f"KV cache scales map for TP rank {i} not found.")
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
|
||||
context = info.context
|
||||
if context:
|
||||
tp_rank = context["tp_rank"]
|
||||
num_hidden_layers = context["num_hidden_layers"]
|
||||
layer_scales_map = self.scaling_factor[tp_rank]
|
||||
for i in range(num_hidden_layers):
|
||||
assert i in layer_scales_map, (
|
||||
f"Could not find KV cache scales for layer {i} in "
|
||||
f"TP rank {tp_rank}.")
|
||||
return self
|
||||
|
||||
|
||||
class QuantParamSchema(BaseModel):
|
||||
# TODO: Generalize and extend with more fields
|
||||
# (e.g. weights/activations params) once functionality is enabled
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
model_type: Optional[str]
|
||||
kv_cache: KVCacheQuantSchema
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
|
||||
context = info.context
|
||||
if context:
|
||||
model_type = context.get("model_type", None)
|
||||
if model_type is not None:
|
||||
assert model_type == self.model_type, (
|
||||
f"Model type is {model_type} but loaded "
|
||||
f"scaling factors belonging to different "
|
||||
f"model type {self.model_type}!")
|
||||
return self
|
||||
214
vllm/model_executor/layers/quantization/torchao.py
Normal file
214
vllm/model_executor/layers/quantization/torchao.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def should_skip(prefix: str, skip_modules: list[str]) -> bool:
|
||||
"""
|
||||
Robust skipping logic:
|
||||
should_skip("model.model.layers.1.q_proj",
|
||||
["model.model.layers.1.q_proj"]) # True
|
||||
should_skip("model.model.layers.10.o_proj", ["o_proj"]) -> True
|
||||
should_skip("visual.model.layers.1.q_proj", ["visual"]) -> True
|
||||
should_skip("model.model.layers.1.q_proj", ["layers.1"]) -> True
|
||||
should_skip("model.model.layers.11.q_proj", ["layers.1"]) -> False
|
||||
"""
|
||||
for s in skip_modules:
|
||||
if prefix == s:
|
||||
return True
|
||||
if f".{s}." in f".{prefix}.":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class TorchAOConfig(QuantizationConfig):
|
||||
"""Config class for torchao."""
|
||||
|
||||
def __init__(self,
|
||||
torchao_config,
|
||||
skip_modules: Optional[list[str]] = None) -> None:
|
||||
"""
|
||||
# TorchAO quantization relies on tensor subclasses. In order,
|
||||
# to enable proper caching this needs standalone compile
|
||||
if is_torch_equal_or_newer("2.8.0.dev"):
|
||||
os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"
|
||||
logger.info(
|
||||
"Using TorchAO: Setting VLLM_TEST_STANDALONE_COMPILE=1")
|
||||
|
||||
# TODO: remove after the torch dependency is updated to 2.8
|
||||
if is_torch_equal_or_newer(
|
||||
"2.7.0") and not is_torch_equal_or_newer("2.8.0.dev"):
|
||||
os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
|
||||
logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
|
||||
"""
|
||||
super().__init__()
|
||||
self.torchao_config = torchao_config
|
||||
self.skip_modules = skip_modules or []
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"TorchAOConfig({self.torchao_config})"
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "torchao"
|
||||
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
return [torch.float32, torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
return ["config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
|
||||
"""Create the quant config from an hf model config"""
|
||||
try:
|
||||
from torchao.core.config import config_from_dict
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"Please install torchao>=0.10.0 via "
|
||||
"`pip install torchao>=0.10.0` to use torchao quantization."
|
||||
) from err
|
||||
|
||||
hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
|
||||
assert hf_config is not None, "quant_type must be specified"
|
||||
assert len(hf_config) == 1 and "default" in hf_config, (
|
||||
"Expected only one key 'default' in quant_type dictionary")
|
||||
quant_type = hf_config["default"]
|
||||
ao_config = config_from_dict(quant_type)
|
||||
|
||||
# Adds skipped modules defined in "modules_to_not_convert"
|
||||
skip_modules = config.get("modules_to_not_convert", []) or []
|
||||
|
||||
# Adds skipped modules defined in "module_fqn_to_config"
|
||||
_data = quant_type.get("_data", {})
|
||||
if not isinstance(_data, dict):
|
||||
_data = {}
|
||||
|
||||
module_fqn = _data.get("module_fqn_to_config", {})
|
||||
if not isinstance(module_fqn, dict):
|
||||
module_fqn = {}
|
||||
|
||||
for layer, layer_cfg in module_fqn.items():
|
||||
if layer_cfg is None:
|
||||
skip_modules.append(layer)
|
||||
|
||||
return cls(ao_config, skip_modules)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if not isinstance(layer, LinearBase):
|
||||
return None
|
||||
|
||||
from torchao.quantization import ModuleFqnToConfig
|
||||
|
||||
if should_skip(prefix, self.skip_modules):
|
||||
return UnquantizedLinearMethod()
|
||||
|
||||
module_fqn = prefix
|
||||
if isinstance(self.torchao_config, ModuleFqnToConfig):
|
||||
module_fqn_to_config = self.torchao_config.module_fqn_to_config
|
||||
c = module_fqn_to_config.get(
|
||||
module_fqn) or module_fqn_to_config.get("_default", None)
|
||||
if c is not None:
|
||||
current_torchao_config = TorchAOConfig(c, self.skip_modules)
|
||||
return TorchAOLinearMethod(current_torchao_config)
|
||||
else:
|
||||
return UnquantizedLinearMethod()
|
||||
|
||||
return TorchAOLinearMethod(self)
|
||||
|
||||
def get_scaled_act_names(self) -> list[str]:
|
||||
return []
|
||||
|
||||
|
||||
def torchao_quantize_param_data(param: torch.Tensor,
|
||||
torchao_config: Any) -> torch.nn.Parameter:
|
||||
"""Quantize a Tensor with torchao quantization specified by torchao_config
|
||||
|
||||
Args:
|
||||
param: weight parameter of the linear module
|
||||
torchao_config: type of quantization and their arguments we want to
|
||||
use to quantize the Tensor
|
||||
"""
|
||||
from torchao.core.config import AOBaseConfig
|
||||
from torchao.quantization import quantize_
|
||||
|
||||
assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
|
||||
"""
|
||||
Avoid real weight allocation for faster load, since we will
|
||||
end up setting it to param.
|
||||
"""
|
||||
with torch.device("meta"):
|
||||
# linear can't be top level module since quantize_ is inplace
|
||||
# while some of our configs need to do module swap, and only non-top
|
||||
# level modules support module swap
|
||||
dummy_linear = torch.nn.Sequential(
|
||||
torch.nn.Linear(param.shape[1], param.shape[0], bias=False))
|
||||
|
||||
dummy_linear[0].weight = param
|
||||
quantize_(dummy_linear, torchao_config)
|
||||
return dummy_linear[0].weight
|
||||
|
||||
|
||||
class TorchAOLinearMethod(LinearMethodBase):
|
||||
"""Linear method for torchao.
|
||||
|
||||
Args:
|
||||
quant_config: The torchao quantization config, a string that encodes
|
||||
the type of quantization and all relevant arguments.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: TorchAOConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
weight = Parameter(
|
||||
torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
weight = torchao_quantize_param_data(weight,
|
||||
self.quant_config.torchao_config)
|
||||
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, extra_weight_attrs)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return F.linear(x, layer.weight, bias)
|
||||
125
vllm/model_executor/layers/quantization/tpu_int8.py
Normal file
125
vllm/model_executor/layers/quantization/tpu_int8.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.parameter import ModelWeightParameter
|
||||
|
||||
ACTIVATION_SCHEMES = ["none", "dynamic"]
|
||||
|
||||
|
||||
class Int8TpuConfig(QuantizationConfig):
|
||||
"""Int8 Quantization Config class for TPU Backend."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
activation_scheme: str = "none",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if activation_scheme not in ACTIVATION_SCHEMES:
|
||||
raise ValueError(
|
||||
f"Unsupported activation scheme {activation_scheme}")
|
||||
self.activation_scheme = activation_scheme
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "tpu_int8"
|
||||
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
raise NotImplementedError(
|
||||
"This function should not be called with TPU Backend")
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "Int8TpuConfig":
|
||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
||||
return cls(activation_scheme=activation_scheme)
|
||||
|
||||
def get_quant_method(self, layer: Module,
|
||||
prefix: str) -> Optional["TPUInt8LinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return TPUInt8LinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class TPUInt8LinearMethod(LinearMethodBase):
|
||||
"""Int8 Linear method for TPU Quant. """
|
||||
|
||||
def __init__(self, quant_config: Int8TpuConfig):
|
||||
self.quant_config = quant_config
|
||||
self.quantize_activation = False
|
||||
if self.quant_config.activation_scheme == 'dynamic':
|
||||
self.quantize_activation = True
|
||||
|
||||
def create_weights(self, layer: Module, input_size_per_partition: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
def _quantize_weight(
|
||||
self, weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
weight_dtype = weight.dtype
|
||||
weight = weight.cpu().to(torch.float32)
|
||||
n_bit = 8
|
||||
eps = 1e-5
|
||||
max_int = 2**(n_bit - 1) - 1
|
||||
min_int = -(2**(n_bit - 1))
|
||||
max_val = weight.abs().amax(dim=-1, keepdim=True)
|
||||
max_val = max_val.clamp(min=eps)
|
||||
qscale = max_val / max_int
|
||||
qweight = torch.clamp(torch.round(weight * (1.0 / qscale)), min_int,
|
||||
max_int).to(torch.int8)
|
||||
qscale = qscale.squeeze().to(weight_dtype)
|
||||
return qweight, qscale
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
||||
device = layer.weight.device
|
||||
qweight, qscale = self._quantize_weight(layer.weight)
|
||||
qweight = qweight.to(device)
|
||||
qscale = qscale.to(device)
|
||||
layer.weight = Parameter(qweight, requires_grad=False)
|
||||
layer.scale = Parameter(qscale, requires_grad=False)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
try:
|
||||
import torch_xla.experimental.custom_kernel # noqa: F401
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"Please install torch_xla by following the instructions at "
|
||||
"https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html " # noqa: E501
|
||||
"to run vLLM on TPU.") from err
|
||||
weight = layer.weight
|
||||
scale = layer.scale
|
||||
out = torch.ops.xla.quantized_matmul_int8(
|
||||
x, weight, scale, quantize_activation=self.quantize_activation)
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out
|
||||
@@ -0,0 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .layer_utils import replace_parameter, update_tensor_inplace
|
||||
|
||||
__all__ = ['update_tensor_inplace', 'replace_parameter']
|
||||
@@ -0,0 +1,52 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD = 1024
|
||||
ALLSPARK_SUPPORTED_QUANT_TYPES = [scalar_types.uint8b128]
|
||||
ALLSPARK_AMPERE_N_ALIGN = 16
|
||||
ALLSPARK_AMPERE_K_ALIGN = 16
|
||||
|
||||
|
||||
def check_allspark_supported_dtype_shape(input_size_per_partition: int,
|
||||
output_size_per_partition: int,
|
||||
group_size: int,
|
||||
weight_dtype: ScalarType,
|
||||
act_dtype: torch.dtype):
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
device_capability = (-1 if capability_tuple is None else
|
||||
capability_tuple.to_int())
|
||||
|
||||
# For Ampere GPU
|
||||
if device_capability >= 80 and device_capability < 90:
|
||||
if group_size != -1:
|
||||
return False, \
|
||||
"For Ampere GPU, AllSpark does not support group_size "\
|
||||
f"= {group_size}. Only group_size = -1 are supported."
|
||||
|
||||
if weight_dtype not in ALLSPARK_SUPPORTED_QUANT_TYPES:
|
||||
return False, "For Ampere GPU, AllSpark does not support "\
|
||||
f"quant type ({weight_dtype}). Only quant type "\
|
||||
f"({ALLSPARK_SUPPORTED_QUANT_TYPES}) are supported."
|
||||
|
||||
if input_size_per_partition % ALLSPARK_AMPERE_K_ALIGN != 0 \
|
||||
or output_size_per_partition % ALLSPARK_AMPERE_N_ALIGN != 0:
|
||||
return False, \
|
||||
"AllSpark needs input_size_per_partition % "\
|
||||
f"{ALLSPARK_AMPERE_K_ALIGN} = 0 and "\
|
||||
f"output_size_per_partition % {ALLSPARK_AMPERE_N_ALIGN} = 0 "\
|
||||
"for Ampere GPU optimized kernels."
|
||||
|
||||
if act_dtype != torch.float16 and act_dtype != torch.bfloat16:
|
||||
return False, \
|
||||
"AllSpark only supports act_dtype = float16 or bfloat16,"\
|
||||
f"for Ampere GPU, but got act_dtype = {act_dtype}."
|
||||
else:
|
||||
return False, "AllSpark currently does not support "\
|
||||
f"device_capability = {device_capability}."
|
||||
|
||||
return True, None
|
||||
210
vllm/model_executor/layers/quantization/utils/bitblas_utils.py
Normal file
210
vllm/model_executor/layers/quantization/utils/bitblas_utils.py
Normal file
@@ -0,0 +1,210 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
MINIMUM_BITBLAS_VERSION = "0.1.0"
|
||||
|
||||
BITBLAS_MIN_WEIGHT_SIZE_N = 16
|
||||
BITBLAS_MIN_WEIGHT_SIZE_K = 16
|
||||
GPTQ_BITBLAS_MAX_PARALLEL = 16
|
||||
|
||||
BITBLAS_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
|
||||
# For dynamic shape code generation
|
||||
BITBLAS_OPTIMIZE_FEATURES = [1, 16, 32, 64, 128, 256, 512, 1024]
|
||||
# If want to enable high performance for contiguous batching
|
||||
# Please use the following values
|
||||
BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS = [16, 32, 64, 128, 256, 512, 1024]
|
||||
|
||||
BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8]
|
||||
BITBLAS_SUPPORTED_SYM = [False, True]
|
||||
|
||||
|
||||
# Determines the supported quantization types for BitBLAS based on the
|
||||
# device's capability and whether zero-point (zp) is used.
|
||||
def query_bitblas_supported_quant_types(has_zp: bool,
|
||||
device_capability: Optional[int] = None
|
||||
):
|
||||
if device_capability is None:
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
device_capability = (-1 if capability_tuple is None else
|
||||
capability_tuple.to_int())
|
||||
|
||||
if device_capability < 70:
|
||||
return []
|
||||
|
||||
if has_zp:
|
||||
# AWQ style, unsigned + runtime zero-point
|
||||
return [scalar_types.uint4, scalar_types.uint8]
|
||||
else:
|
||||
# GPTQ style, unsigned + symmetric bias
|
||||
# TODO: once fp8_bitblas is merged into "gptq_bitblas" we should be able
|
||||
# to add `scalar_types.float8_e4m3fn` here
|
||||
return [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
|
||||
|
||||
def _check_bitblas_supported(
|
||||
quant_type: ScalarType,
|
||||
group_size: Optional[int],
|
||||
has_zp: bool,
|
||||
device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]:
|
||||
|
||||
if device_capability is None:
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
device_capability = (-1 if capability_tuple is None else
|
||||
capability_tuple.to_int())
|
||||
|
||||
supported_types = query_bitblas_supported_quant_types(
|
||||
has_zp, device_capability)
|
||||
|
||||
if quant_type not in supported_types:
|
||||
return (False, f"BitBLAS does not support weight_bits = {quant_type}. "
|
||||
f"Only types = {supported_types} "
|
||||
f"are supported (for group_size = {group_size}, "
|
||||
f"device_capability = {device_capability}, zp = {has_zp}).")
|
||||
if (group_size is None or group_size not in BITBLAS_SUPPORTED_GROUP_SIZES):
|
||||
return (False, f"BitBLAS does not support group_size = {group_size}. "
|
||||
f"Only group_sizes = {BITBLAS_SUPPORTED_GROUP_SIZES} "
|
||||
"are supported.")
|
||||
|
||||
# Finally, check if bitblas is installed
|
||||
try:
|
||||
import bitblas
|
||||
if version.parse(
|
||||
bitblas.__version__) < version.parse(MINIMUM_BITBLAS_VERSION):
|
||||
raise ImportError("bitblas version is wrong. Please "
|
||||
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
|
||||
except ImportError:
|
||||
return False, "BitBLAS is not installed."
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
def check_bitblas_supported(quant_type: ScalarType,
|
||||
group_size: int,
|
||||
has_zp: bool = False,
|
||||
device_capability: Optional[int] = None) -> bool:
|
||||
cond, _ = _check_bitblas_supported(quant_type, group_size, has_zp,
|
||||
device_capability)
|
||||
return cond
|
||||
|
||||
|
||||
def verify_bitblas_supported(quant_type: ScalarType,
|
||||
group_size: int,
|
||||
has_zp: bool = False) -> None:
|
||||
cond, err_msg = _check_bitblas_supported(quant_type, group_size, has_zp)
|
||||
if not cond:
|
||||
assert err_msg is not None
|
||||
raise ValueError(err_msg)
|
||||
|
||||
|
||||
def verify_bitblas_supports_shape(output_size_per_partition: int,
|
||||
input_size_per_partition: int,
|
||||
input_size: int, group_size: int) -> None:
|
||||
|
||||
# Validate output_size_per_partition
|
||||
if output_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_N != 0:
|
||||
raise ValueError(f"Weight output_size_per_partition = "
|
||||
f"{output_size_per_partition} is not divisible by "
|
||||
f" min_thread_n = {BITBLAS_MIN_WEIGHT_SIZE_N}. "
|
||||
"Consider reducing tensor_parallel_size or running "
|
||||
"with --quantization gptq.")
|
||||
|
||||
# Validate input_size_per_partition
|
||||
if input_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_K != 0:
|
||||
raise ValueError(f"Weight input_size_per_partition = "
|
||||
f"{input_size_per_partition} is not divisible "
|
||||
f"by min_thread_k = {BITBLAS_MIN_WEIGHT_SIZE_K}. "
|
||||
"Consider reducing tensor_parallel_size or running "
|
||||
"with --quantization gptq.")
|
||||
|
||||
if (group_size < input_size
|
||||
and input_size_per_partition % group_size != 0):
|
||||
raise ValueError(
|
||||
f"Weight input_size_per_partition = {input_size_per_partition}"
|
||||
f" is not divisible by group_size = {group_size}."
|
||||
"Consider reducing tensor_parallel_size or running "
|
||||
"with --quantization gptq.")
|
||||
|
||||
|
||||
def check_bitblas_supports_shape(output_size_per_partition: int,
|
||||
input_size_per_partition: int,
|
||||
input_size: int, group_size: int) \
|
||||
-> tuple[bool, Optional[str]]:
|
||||
try:
|
||||
verify_bitblas_supports_shape(output_size_per_partition,
|
||||
input_size_per_partition, input_size,
|
||||
group_size)
|
||||
except ValueError as e:
|
||||
return False, e.__str__()
|
||||
return True, None
|
||||
|
||||
|
||||
def bitblas_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
|
||||
return (not act_order) or (act_order and not is_row_parallel)
|
||||
|
||||
|
||||
def bitblas_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
|
||||
is_row_parallel: bool) -> bool:
|
||||
# Need to repeat scales on every rank if act_ordering or
|
||||
# channelwise and RowParallelLinear
|
||||
is_channelwise = group_size == -1
|
||||
return act_order or (is_channelwise and is_row_parallel)
|
||||
|
||||
|
||||
def bitblas_make_empty_g_idx(device: torch.device) -> torch.Tensor:
|
||||
return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
|
||||
requires_grad=False)
|
||||
|
||||
|
||||
def bitblas_make_empty_zp(device: torch.device) -> torch.Tensor:
|
||||
return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
|
||||
requires_grad=False)
|
||||
|
||||
|
||||
def bitblas_sort_g_idx(
|
||||
g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
|
||||
return g_idx[g_idx_sort_indices], g_idx_sort_indices
|
||||
|
||||
|
||||
def unpack_gptq_qzeros(qzeros, bits, is_gptq_v2=False) -> torch.Tensor:
|
||||
qzeros = qzeros.view(torch.int32)
|
||||
elems_per_int32 = 32 // bits
|
||||
unpacked_zeros = torch.zeros(
|
||||
(qzeros.shape[0], qzeros.shape[1] * elems_per_int32),
|
||||
dtype=torch.int8,
|
||||
device=qzeros.device,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
for col in range(unpacked_zeros.shape[1]):
|
||||
i = col % elems_per_int32
|
||||
unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >>
|
||||
(bits * i)) & 0xF
|
||||
if not is_gptq_v2:
|
||||
return unpacked_zeros + 1
|
||||
return unpacked_zeros
|
||||
|
||||
|
||||
def unpack_gptq_qweight(qweight, bits):
|
||||
qweight = qweight.view(torch.int8)
|
||||
elems_per_int8 = 8 // bits
|
||||
unpacked_weight = torch.zeros(
|
||||
(qweight.shape[0], qweight.shape[1] * elems_per_int8),
|
||||
dtype=torch.int8,
|
||||
device=qweight.device,
|
||||
requires_grad=False,
|
||||
)
|
||||
for col in range(unpacked_weight.shape[1]):
|
||||
i = col % elems_per_int8
|
||||
unpacked_weight[:, col] = (qweight[:, col // elems_per_int8] >>
|
||||
(bits * i))
|
||||
|
||||
return torch.bitwise_and(unpacked_weight, 2**bits - 1)
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"kpack": 1,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"num_warps": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user